diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index f42af328e..7a7c6087d 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,4 +1,14 @@ -ARG IMAGE=ghcr.io/newrelic-experimental/pyenv-devcontainer:latest - # To target other architectures, change the --platform directive in the Dockerfile. -FROM --platform=linux/amd64 ${IMAGE} +ARG IMAGE_TAG=latest +FROM ghcr.io/newrelic/newrelic-python-agent-ci:${IMAGE_TAG} + +# Setup non-root user +USER root +ARG UID=1000 +ARG GID=$UID +ENV HOME /home/vscode +RUN mkdir -p ${HOME} && \ + groupadd --gid ${GID} vscode && \ + useradd --uid ${UID} --gid ${GID} --home ${HOME} vscode && \ + chown -R ${UID}:${GID} /home/vscode +USER ${UID}:${GID} diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 92a8cdee4..fbefff476 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -5,7 +5,7 @@ // To target other architectures, change the --platform directive in the Dockerfile. "dockerfile": "Dockerfile", "args": { - "IMAGE": "ghcr.io/newrelic-experimental/pyenv-devcontainer:latest" + "IMAGE_TAG": "latest" } }, "remoteUser": "vscode", diff --git a/.github/containers/Dockerfile b/.github/containers/Dockerfile index 2fbefb14a..d761b6f4a 100644 --- a/.github/containers/Dockerfile +++ b/.github/containers/Dockerfile @@ -23,6 +23,8 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ build-essential \ curl \ expat \ + fish \ + fontconfig \ freetds-common \ freetds-dev \ gcc \ @@ -46,13 +48,16 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ python2-dev \ python3-dev \ python3-pip \ + sudo \ tzdata \ unixodbc-dev \ unzip \ + vim \ wget \ zip \ zlib1g \ - zlib1g-dev && \ + zlib1g-dev \ + zsh && \ rm -rf /var/lib/apt/lists/* # Build librdkafka from source @@ -93,6 +98,10 @@ RUN echo 'eval "$(pyenv init -)"' >>$HOME/.bashrc && \ # Install Python ARG PYTHON_VERSIONS="3.10 3.9 3.8 3.7 3.11 2.7 pypy2.7-7.3.12 pypy3.8-7.3.11" COPY --chown=1000:1000 --chmod=+x ./install-python.sh /tmp/install-python.sh -COPY ./requirements.txt /requirements.txt RUN /tmp/install-python.sh && \ rm /tmp/install-python.sh + +# Install dependencies for main python installation +COPY ./requirements.txt /tmp/requirements.txt +RUN pyenv exec pip install --upgrade -r /tmp/requirements.txt && \ + rm /tmp/requirements.txt \ No newline at end of file diff --git a/.github/containers/Makefile b/.github/containers/Makefile index 35081f738..4c057813d 100644 --- a/.github/containers/Makefile +++ b/.github/containers/Makefile @@ -19,16 +19,16 @@ REPO_ROOT:=$(realpath $(MAKEFILE_DIR)../../) .PHONY: default default: test +# Perform a shortened build for testing .PHONY: build build: - @# Perform a shortened build for testing @docker build $(MAKEFILE_DIR) \ -t ghcr.io/newrelic/newrelic-python-agent-ci:local \ --build-arg='PYTHON_VERSIONS=3.10 2.7' +# Ensure python versions are usable .PHONY: test test: build - @# Ensure python versions are usable @docker run --rm ghcr.io/newrelic/python-agent-ci:local /bin/bash -c '\ python3.10 --version && \ python2.7 --version && \ diff --git a/.github/containers/requirements.txt b/.github/containers/requirements.txt index 27fa6624b..68bdfe4fe 100644 --- a/.github/containers/requirements.txt +++ b/.github/containers/requirements.txt @@ -1,5 +1,9 @@ +bandit +black +flake8 +isort pip setuptools -wheel +tox virtualenv<20.22.0 -tox \ No newline at end of file +wheel \ No newline at end of file diff --git a/.github/stale.yml b/.github/stale.yml index 9d84541db..39e994219 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -13,7 +13,7 @@ # limitations under the License. # # Number of days of inactivity before an issue becomes stale -daysUntilStale: 60 +daysUntilStale: 365 # Number of days of inactivity before a stale issue is closed # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. daysUntilClose: false diff --git a/.github/workflows/build-ci-image.yml b/.github/workflows/build-ci-image.yml index 5bd0e6f69..8bd904661 100644 --- a/.github/workflows/build-ci-image.yml +++ b/.github/workflows/build-ci-image.yml @@ -63,6 +63,6 @@ jobs: with: push: ${{ github.event_name != 'pull_request' }} context: .github/containers - platforms: linux/amd64 + platforms: ${{ (github.ref == 'refs/head/main') && 'linux/amd64,linux/arm64' || 'linux/amd64' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 943958197..e3b264a9f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -46,6 +46,7 @@ jobs: - postgres - rabbitmq - redis + - rediscluster - solr steps: @@ -384,6 +385,105 @@ jobs: path: ./**/.coverage.* retention-days: 1 + rediscluster: + env: + TOTAL_GROUPS: 1 + + strategy: + fail-fast: false + matrix: + group-number: [1] + + runs-on: ubuntu-20.04 + container: + image: ghcr.io/newrelic/newrelic-python-agent-ci:latest + options: >- + --add-host=host.docker.internal:host-gateway + timeout-minutes: 30 + + services: + redis1: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6379:6379 + - 16379:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + redis2: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6380:6379 + - 16380:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + redis3: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6381:6379 + - 16381:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + redis4: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6382:6379 + - 16382:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + redis5: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6383:6379 + - 16383:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + redis6: + image: hmstepanek/redis-cluster-node:1.0.0 + ports: + - 6384:6379 + - 16384:16379 + options: >- + --add-host=host.docker.internal:host-gateway + + cluster-setup: + image: hmstepanek/redis-cluster:1.0.0 + options: >- + --add-host=host.docker.internal:host-gateway + + steps: + - uses: actions/checkout@v3 + + - name: Fetch git tags + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + git fetch --tags origin + + - name: Get Environments + id: get-envs + run: | + echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT + env: + GROUP_NUMBER: ${{ matrix.group-number }} + + - name: Test + run: | + tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto + env: + TOX_PARALLEL_NO_SPINNER: 1 + PY_COLORS: 0 + + - name: Upload Coverage Artifacts + uses: actions/upload-artifact@v3 + with: + name: coverage-${{ github.job }}-${{ strategy.job-index }} + path: ./**/.coverage.* + retention-days: 1 + redis: env: TOTAL_GROUPS: 2 @@ -924,3 +1024,68 @@ jobs: name: coverage-${{ github.job }}-${{ strategy.job-index }} path: ./**/.coverage.* retention-days: 1 + + firestore: + env: + TOTAL_GROUPS: 1 + + strategy: + fail-fast: false + matrix: + group-number: [1] + + runs-on: ubuntu-20.04 + container: + image: ghcr.io/newrelic/newrelic-python-agent-ci:latest + options: >- + --add-host=host.docker.internal:host-gateway + timeout-minutes: 30 + + services: + firestore: + # Image set here MUST be repeated down below in options. See comment below. + image: gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators + ports: + - 8080:8080 + # Set health checks to wait 5 seconds in lieu of an actual healthcheck + options: >- + --health-cmd "echo success" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + --health-start-period 5s + gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators /bin/bash -c "gcloud emulators firestore start --host-port=0.0.0.0:8080" || + # This is a very hacky solution. GitHub Actions doesn't provide APIs for setting commands on services, but allows adding arbitrary options. + # --entrypoint won't work as it only accepts an executable and not the [] syntax. + # Instead, we specify the image again the command afterwards like a call to docker create. The result is a few environment variables + # and the original command being appended to our hijacked docker create command. We can avoid any issues by adding || to prevent that + # from every being executed as bash commands. + + steps: + - uses: actions/checkout@v3 + + - name: Fetch git tags + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + git fetch --tags origin + + - name: Get Environments + id: get-envs + run: | + echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT + env: + GROUP_NUMBER: ${{ matrix.group-number }} + + - name: Test + run: | + tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto + env: + TOX_PARALLEL_NO_SPINNER: 1 + PY_COLORS: 0 + + - name: Upload Coverage Artifacts + uses: actions/upload-artifact@v3 + with: + name: coverage-${{ github.job }}-${{ strategy.job-index }} + path: ./**/.coverage.* + retention-days: 1 diff --git a/newrelic/api/database_trace.py b/newrelic/api/database_trace.py index 2bc497688..8990a1ef4 100644 --- a/newrelic/api/database_trace.py +++ b/newrelic/api/database_trace.py @@ -16,7 +16,7 @@ import logging from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.database_node import DatabaseNode from newrelic.core.stack_trace import current_stack @@ -44,11 +44,6 @@ def register_database_client( dbapi2_module._nr_explain_query = explain_query dbapi2_module._nr_explain_stmts = explain_stmts dbapi2_module._nr_instance_info = instance_info - dbapi2_module._nr_datastore_instance_feature_flag = False - - -def enable_datastore_instance_feature(dbapi2_module): - dbapi2_module._nr_datastore_instance_feature_flag = True class DatabaseTrace(TimeTrace): @@ -153,12 +148,7 @@ def finalize_data(self, transaction, exc=None, value=None, tb=None): if instance_enabled or db_name_enabled: - if ( - self.dbapi2_module - and self.connect_params - and self.dbapi2_module._nr_datastore_instance_feature_flag - and self.dbapi2_module._nr_instance_info is not None - ): + if self.dbapi2_module and self.connect_params and self.dbapi2_module._nr_instance_info is not None: instance_info = self.dbapi2_module._nr_instance_info(*self.connect_params) @@ -244,9 +234,9 @@ def create_node(self): ) -def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None): +def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None, async_wrapper=None): def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -273,9 +263,9 @@ def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_database_trace_wrapper_) -def database_trace(sql, dbapi2_module=None): - return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module) +def database_trace(sql, dbapi2_module=None, async_wrapper=None): + return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module, async_wrapper=async_wrapper) -def wrap_database_trace(module, object_path, sql, dbapi2_module=None): - wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module)) +def wrap_database_trace(module, object_path, sql, dbapi2_module=None, async_wrapper=None): + wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module, async_wrapper)) diff --git a/newrelic/api/datastore_trace.py b/newrelic/api/datastore_trace.py index fb40abcab..0401c79ea 100644 --- a/newrelic/api/datastore_trace.py +++ b/newrelic/api/datastore_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.datastore_node import DatastoreNode @@ -82,6 +82,9 @@ def __enter__(self): self.product = transaction._intern_string(self.product) self.target = transaction._intern_string(self.target) self.operation = transaction._intern_string(self.operation) + self.host = transaction._intern_string(self.host) + self.port_path_or_id = transaction._intern_string(self.port_path_or_id) + self.database_name = transaction._intern_string(self.database_name) datastore_tracer_settings = transaction.settings.datastore_tracer self.instance_reporting_enabled = datastore_tracer_settings.instance_reporting.enabled @@ -92,7 +95,14 @@ def __repr__(self): return "<%s object at 0x%x %s>" % ( self.__class__.__name__, id(self), - dict(product=self.product, target=self.target, operation=self.operation), + dict( + product=self.product, + target=self.target, + operation=self.operation, + host=self.host, + port_path_or_id=self.port_path_or_id, + database_name=self.database_name, + ), ) def finalize_data(self, transaction, exc=None, value=None, tb=None): @@ -125,7 +135,7 @@ def create_node(self): ) -def DatastoreTraceWrapper(wrapped, product, target, operation): +def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Wraps a method to time datastore queries. :param wrapped: The function to apply the trace to. @@ -140,6 +150,16 @@ def DatastoreTraceWrapper(wrapped, product, target, operation): or the name of any API function/method in the client library. :type operation: str or callable + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. + :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None :rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper` This is typically used to wrap datastore queries such as calls to Redis or @@ -155,7 +175,7 @@ def DatastoreTraceWrapper(wrapped, product, target, operation): """ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -187,7 +207,33 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): else: _operation = operation - trace = DatastoreTrace(_product, _target, _operation, parent=parent, source=wrapped) + if callable(host): + if instance is not None: + _host = host(instance, *args, **kwargs) + else: + _host = host(*args, **kwargs) + else: + _host = host + + if callable(port_path_or_id): + if instance is not None: + _port_path_or_id = port_path_or_id(instance, *args, **kwargs) + else: + _port_path_or_id = port_path_or_id(*args, **kwargs) + else: + _port_path_or_id = port_path_or_id + + if callable(database_name): + if instance is not None: + _database_name = database_name(instance, *args, **kwargs) + else: + _database_name = database_name(*args, **kwargs) + else: + _database_name = database_name + + trace = DatastoreTrace( + _product, _target, _operation, _host, _port_path_or_id, _database_name, parent=parent, source=wrapped + ) if wrapper: # pylint: disable=W0125,W0126 return wrapper(wrapped, trace)(*args, **kwargs) @@ -198,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_) -def datastore_trace(product, target, operation): +def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Decorator allows datastore query to be timed. :param product: The name of the vendor. @@ -211,6 +257,16 @@ def datastore_trace(product, target, operation): or the name of any API function/method in the client library. :type operation: str + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. + :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to decorate datastore queries such as calls to Redis or ElasticSearch. @@ -224,10 +280,21 @@ def datastore_trace(product, target, operation): ... time.sleep(*args, **kwargs) """ - return functools.partial(DatastoreTraceWrapper, product=product, target=target, operation=operation) - - -def wrap_datastore_trace(module, object_path, product, target, operation): + return functools.partial( + DatastoreTraceWrapper, + product=product, + target=target, + operation=operation, + host=host, + port_path_or_id=port_path_or_id, + database_name=database_name, + async_wrapper=async_wrapper, + ) + + +def wrap_datastore_trace( + module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None +): """Method applies custom timing to datastore query. :param module: Module containing the method to be instrumented. @@ -244,6 +311,16 @@ def wrap_datastore_trace(module, object_path, product, target, operation): or the name of any API function/method in the client library. :type operation: str + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. + :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to time database query method calls such as Redis GET. @@ -256,4 +333,6 @@ def wrap_datastore_trace(module, object_path, product, target, operation): ... 'sleep') """ - wrap_object(module, object_path, DatastoreTraceWrapper, (product, target, operation)) + wrap_object( + module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper) + ) diff --git a/newrelic/api/external_trace.py b/newrelic/api/external_trace.py index c43c560c6..2e147df45 100644 --- a/newrelic/api/external_trace.py +++ b/newrelic/api/external_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.external_node import ExternalNode @@ -66,9 +66,9 @@ def create_node(self): ) -def ExternalTraceWrapper(wrapped, library, url, method=None): +def ExternalTraceWrapper(wrapped, library, url, method=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -103,7 +103,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -125,9 +125,9 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def external_trace(library, url, method=None): - return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method) +def external_trace(library, url, method=None, async_wrapper=None): + return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method, async_wrapper=async_wrapper) -def wrap_external_trace(module, object_path, library, url, method=None): - wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method)) +def wrap_external_trace(module, object_path, library, url, method=None, async_wrapper=None): + wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method, async_wrapper)) diff --git a/newrelic/api/function_trace.py b/newrelic/api/function_trace.py index 474c1b226..85d7617b6 100644 --- a/newrelic/api/function_trace.py +++ b/newrelic/api/function_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.function_node import FunctionNode @@ -89,9 +89,9 @@ def create_node(self): ) -def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -147,7 +147,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -171,13 +171,13 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): return functools.partial( - FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup + FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup, async_wrapper=async_wrapper ) def wrap_function_trace( - module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None + module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None ): - return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup)) + return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup, async_wrapper)) diff --git a/newrelic/api/graphql_trace.py b/newrelic/api/graphql_trace.py index 7a2c9ec02..e8803fa68 100644 --- a/newrelic/api/graphql_trace.py +++ b/newrelic/api/graphql_trace.py @@ -16,7 +16,7 @@ from newrelic.api.time_trace import TimeTrace, current_trace from newrelic.api.transaction import current_transaction -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.graphql_node import GraphQLOperationNode, GraphQLResolverNode @@ -109,9 +109,9 @@ def set_transaction_name(self, priority=None): transaction.set_transaction_name(name, "GraphQL", priority=priority) -def GraphQLOperationTraceWrapper(wrapped): +def GraphQLOperationTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -130,16 +130,16 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_operation_trace(): - return functools.partial(GraphQLOperationTraceWrapper) +def graphql_operation_trace(async_wrapper=None): + return functools.partial(GraphQLOperationTraceWrapper, async_wrapper=async_wrapper) -def wrap_graphql_operation_trace(module, object_path): - wrap_object(module, object_path, GraphQLOperationTraceWrapper) +def wrap_graphql_operation_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLOperationTraceWrapper, (async_wrapper,)) class GraphQLResolverTrace(TimeTrace): - def __init__(self, field_name=None, **kwargs): + def __init__(self, field_name=None, field_parent_type=None, field_return_type=None, field_path=None, **kwargs): parent = kwargs.pop("parent", None) source = kwargs.pop("source", None) if kwargs: @@ -148,6 +148,9 @@ def __init__(self, field_name=None, **kwargs): super(GraphQLResolverTrace, self).__init__(parent=parent, source=source) self.field_name = field_name + self.field_parent_type = field_parent_type + self.field_return_type = field_return_type + self.field_path = field_path self._product = None def __repr__(self): @@ -175,6 +178,9 @@ def product(self): def finalize_data(self, *args, **kwargs): self._add_agent_attribute("graphql.field.name", self.field_name) + self._add_agent_attribute("graphql.field.parentType", self.field_parent_type) + self._add_agent_attribute("graphql.field.returnType", self.field_return_type) + self._add_agent_attribute("graphql.field.path", self.field_path) return super(GraphQLResolverTrace, self).finalize_data(*args, **kwargs) @@ -193,9 +199,9 @@ def create_node(self): ) -def GraphQLResolverTraceWrapper(wrapped): +def GraphQLResolverTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -214,9 +220,9 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_resolver_trace(): - return functools.partial(GraphQLResolverTraceWrapper) +def graphql_resolver_trace(async_wrapper=None): + return functools.partial(GraphQLResolverTraceWrapper, async_wrapper=async_wrapper) -def wrap_graphql_resolver_trace(module, object_path): - wrap_object(module, object_path, GraphQLResolverTraceWrapper) +def wrap_graphql_resolver_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLResolverTraceWrapper, (async_wrapper,)) diff --git a/newrelic/api/memcache_trace.py b/newrelic/api/memcache_trace.py index 6657a9ce2..87f12f9fc 100644 --- a/newrelic/api/memcache_trace.py +++ b/newrelic/api/memcache_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.memcache_node import MemcacheNode @@ -51,9 +51,9 @@ def create_node(self): ) -def MemcacheTraceWrapper(wrapped, command): +def MemcacheTraceWrapper(wrapped, command, async_wrapper=None): def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -80,9 +80,9 @@ def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_wrapper_memcache_trace_) -def memcache_trace(command): - return functools.partial(MemcacheTraceWrapper, command=command) +def memcache_trace(command, async_wrapper=None): + return functools.partial(MemcacheTraceWrapper, command=command, async_wrapper=async_wrapper) -def wrap_memcache_trace(module, object_path, command): - wrap_object(module, object_path, MemcacheTraceWrapper, (command,)) +def wrap_memcache_trace(module, object_path, command, async_wrapper=None): + wrap_object(module, object_path, MemcacheTraceWrapper, (command, async_wrapper)) diff --git a/newrelic/api/message_trace.py b/newrelic/api/message_trace.py index be819d704..f564c41cb 100644 --- a/newrelic/api/message_trace.py +++ b/newrelic/api/message_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.message_node import MessageNode @@ -91,9 +91,9 @@ def create_node(self): ) -def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True): +def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -144,7 +144,7 @@ def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_message_trace_wrapper_) -def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True): +def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): return functools.partial( MessageTraceWrapper, library=library, @@ -153,10 +153,11 @@ def message_trace(library, operation, destination_type, destination_name, params destination_name=destination_name, params=params, terminal=terminal, + async_wrapper=async_wrapper, ) -def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True): +def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): wrap_object( - module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal) + module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal, async_wrapper) ) diff --git a/newrelic/common/agent_http.py b/newrelic/common/agent_http.py index 555816796..89876a60c 100644 --- a/newrelic/common/agent_http.py +++ b/newrelic/common/agent_http.py @@ -113,9 +113,7 @@ def _supportability_request(params, payload, body, compression_time): pass @classmethod - def log_request( - cls, fp, method, url, params, payload, headers, body=None, compression_time=None - ): + def log_request(cls, fp, method, url, params, payload, headers, body=None, compression_time=None): cls._supportability_request(params, payload, body, compression_time) if not fp: @@ -127,7 +125,8 @@ def log_request( cls.AUDIT_LOG_ID += 1 print( - "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp, + "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + file=fp, ) print(file=fp) print("ID: %r" % cls.AUDIT_LOG_ID, file=fp) @@ -179,9 +178,7 @@ def log_response(cls, fp, log_id, status, headers, data, connection="direct"): except Exception: result = data - print( - "TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp - ) + print("TIME: %r" % time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), file=fp) print(file=fp) print("ID: %r" % log_id, file=fp) print(file=fp) @@ -220,9 +217,7 @@ def send_request( class HttpClient(BaseClient): CONNECTION_CLS = urllib3.HTTPSConnectionPool PREFIX_SCHEME = "https://" - BASE_HEADERS = urllib3.make_headers( - keep_alive=True, accept_encoding=True, user_agent=USER_AGENT - ) + BASE_HEADERS = urllib3.make_headers(keep_alive=True, accept_encoding=True, user_agent=USER_AGENT) def __init__( self, @@ -266,11 +261,9 @@ def __init__( # If there is no resolved cafile, assume the bundled certs are # required and report this condition as a supportability metric. - if not verify_path.cafile: + if not verify_path.cafile and not verify_path.capath: ca_bundle_path = certs.where() - internal_metric( - "Supportability/Python/Certificate/BundleRequired", 1 - ) + internal_metric("Supportability/Python/Certificate/BundleRequired", 1) if ca_bundle_path: if os.path.isdir(ca_bundle_path): @@ -282,11 +275,13 @@ def __init__( connection_kwargs["cert_reqs"] = "NONE" proxy = self._parse_proxy( - proxy_scheme, proxy_host, proxy_port, proxy_user, proxy_pass, - ) - proxy_headers = ( - proxy and proxy.auth and urllib3.make_headers(proxy_basic_auth=proxy.auth) + proxy_scheme, + proxy_host, + proxy_port, + proxy_user, + proxy_pass, ) + proxy_headers = proxy and proxy.auth and urllib3.make_headers(proxy_basic_auth=proxy.auth) if proxy: if self.CONNECTION_CLS.scheme == "https" and proxy.scheme != "https": @@ -346,15 +341,9 @@ def _connection(self): if self._connection_attr: return self._connection_attr - retries = urllib3.Retry( - total=False, connect=None, read=None, redirect=0, status=None - ) + retries = urllib3.Retry(total=False, connect=None, read=None, redirect=0, status=None) self._connection_attr = self.CONNECTION_CLS( - self._host, - self._port, - strict=True, - retries=retries, - **self._connection_kwargs + self._host, self._port, strict=True, retries=retries, **self._connection_kwargs ) return self._connection_attr @@ -377,9 +366,7 @@ def log_request( if not self._prefix: url = self.CONNECTION_CLS.scheme + "://" + self._host + url - return super(HttpClient, self).log_request( - fp, method, url, params, payload, headers, body, compression_time - ) + return super(HttpClient, self).log_request(fp, method, url, params, payload, headers, body, compression_time) @staticmethod def _compress(data, method="gzip", level=None): @@ -442,16 +429,16 @@ def send_request( try: response = self._connection.request_encode_url( - method, - path, - fields=params, - body=body, - headers=merged_headers, - **self._urlopen_kwargs + method, path, fields=params, body=body, headers=merged_headers, **self._urlopen_kwargs ) except urllib3.exceptions.HTTPError as e: self.log_response( - self._audit_log_fp, request_id, 0, None, None, connection, + self._audit_log_fp, + request_id, + 0, + None, + None, + connection, ) # All urllib3 HTTP errors should be treated as a network # interface exception. @@ -539,9 +526,7 @@ def _supportability_request(params, payload, body, compression_time): "Supportability/Python/Collector/%s/ZLIB/Bytes" % agent_method, len(body), ) - internal_metric( - "Supportability/Python/Collector/ZLIB/Bytes", len(body) - ) + internal_metric("Supportability/Python/Collector/ZLIB/Bytes", len(body)) internal_metric( "Supportability/Python/Collector/%s/ZLIB/Compress" % agent_method, compression_time, @@ -551,28 +536,21 @@ def _supportability_request(params, payload, body, compression_time): len(payload), ) # Top level metric to aggregate overall bytes being sent - internal_metric( - "Supportability/Python/Collector/Output/Bytes", len(payload) - ) + internal_metric("Supportability/Python/Collector/Output/Bytes", len(payload)) @staticmethod def _supportability_response(status, exc, connection="direct"): if exc or not 200 <= status < 300: internal_count_metric("Supportability/Python/Collector/Failures", 1) - internal_count_metric( - "Supportability/Python/Collector/Failures/%s" % connection, 1 - ) + internal_count_metric("Supportability/Python/Collector/Failures/%s" % connection, 1) if exc: internal_count_metric( - "Supportability/Python/Collector/Exception/" - "%s" % callable_name(exc), + "Supportability/Python/Collector/Exception/" "%s" % callable_name(exc), 1, ) else: - internal_count_metric( - "Supportability/Python/Collector/HTTPError/%d" % status, 1 - ) + internal_count_metric("Supportability/Python/Collector/HTTPError/%d" % status, 1) class ApplicationModeClient(SupportabilityMixin, HttpClient): @@ -581,33 +559,31 @@ class ApplicationModeClient(SupportabilityMixin, HttpClient): class DeveloperModeClient(SupportabilityMixin, BaseClient): RESPONSES = { - "preconnect": {u"redirect_host": u"fake-collector.newrelic.com"}, + "preconnect": {"redirect_host": "fake-collector.newrelic.com"}, "agent_settings": [], "connect": { - u"js_agent_loader": u"", - u"js_agent_file": u"fake-js-agent.newrelic.com/nr-0.min.js", - u"browser_key": u"1234567890", - u"browser_monitoring.loader_version": u"0", - u"beacon": u"fake-beacon.newrelic.com", - u"error_beacon": u"fake-jserror.newrelic.com", - u"apdex_t": 0.5, - u"encoding_key": u"1111111111111111111111111111111111111111", - u"entity_guid": u"DEVELOPERMODEENTITYGUID", - u"agent_run_id": u"1234567", - u"product_level": 50, - u"trusted_account_ids": [12345], - u"trusted_account_key": u"12345", - u"url_rules": [], - u"collect_errors": True, - u"account_id": u"12345", - u"cross_process_id": u"12345#67890", - u"messages": [ - {u"message": u"Reporting to fake collector", u"level": u"INFO"} - ], - u"sampling_rate": 0, - u"collect_traces": True, - u"collect_span_events": True, - u"data_report_period": 60, + "js_agent_loader": "", + "js_agent_file": "fake-js-agent.newrelic.com/nr-0.min.js", + "browser_key": "1234567890", + "browser_monitoring.loader_version": "0", + "beacon": "fake-beacon.newrelic.com", + "error_beacon": "fake-jserror.newrelic.com", + "apdex_t": 0.5, + "encoding_key": "1111111111111111111111111111111111111111", + "entity_guid": "DEVELOPERMODEENTITYGUID", + "agent_run_id": "1234567", + "product_level": 50, + "trusted_account_ids": [12345], + "trusted_account_key": "12345", + "url_rules": [], + "collect_errors": True, + "account_id": "12345", + "cross_process_id": "12345#67890", + "messages": [{"message": "Reporting to fake collector", "level": "INFO"}], + "sampling_rate": 0, + "collect_traces": True, + "collect_span_events": True, + "data_report_period": 60, }, "metric_data": None, "get_agent_commands": [], @@ -651,7 +627,11 @@ def send_request( payload = {"return_value": result} response_data = json_encode(payload).encode("utf-8") self.log_response( - self._audit_log_fp, request_id, 200, {}, response_data, + self._audit_log_fp, + request_id, + 200, + {}, + response_data, ) return 200, response_data diff --git a/newrelic/common/async_wrapper.py b/newrelic/common/async_wrapper.py index c5f95308d..2d3db2b4b 100644 --- a/newrelic/common/async_wrapper.py +++ b/newrelic/common/async_wrapper.py @@ -18,7 +18,9 @@ is_coroutine_callable, is_asyncio_coroutine, is_generator_function, + is_async_generator_function, ) +from newrelic.packages import six def evaluate_wrapper(wrapper_string, wrapped, trace): @@ -29,7 +31,6 @@ def evaluate_wrapper(wrapper_string, wrapped, trace): def coroutine_wrapper(wrapped, trace): - WRAPPER = textwrap.dedent(""" @functools.wraps(wrapped) async def wrapper(*args, **kwargs): @@ -61,29 +62,76 @@ def wrapper(*args, **kwargs): return wrapped -def generator_wrapper(wrapped, trace): - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - g = wrapped(*args, **kwargs) - value = None - with trace: - while True: +if six.PY3: + def generator_wrapper(wrapped, trace): + WRAPPER = textwrap.dedent(""" + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + with trace: + result = yield from wrapped(*args, **kwargs) + return result + """) + + try: + return evaluate_wrapper(WRAPPER, wrapped, trace) + except: + return wrapped +else: + def generator_wrapper(wrapped, trace): + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + g = wrapped(*args, **kwargs) + with trace: try: - yielded = g.send(value) + yielded = g.send(None) + while True: + try: + sent = yield yielded + except GeneratorExit as e: + g.close() + raise + except BaseException as e: + yielded = g.throw(e) + else: + yielded = g.send(sent) except StopIteration: - break + return + return wrapper - try: - value = yield yielded - except BaseException as e: - value = yield g.throw(type(e), e) - return wrapper +def async_generator_wrapper(wrapped, trace): + WRAPPER = textwrap.dedent(""" + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + g = wrapped(*args, **kwargs) + with trace: + try: + yielded = await g.asend(None) + while True: + try: + sent = yield yielded + except GeneratorExit as e: + await g.aclose() + raise + except BaseException as e: + yielded = await g.athrow(e) + else: + yielded = await g.asend(sent) + except StopAsyncIteration: + return + """) + + try: + return evaluate_wrapper(WRAPPER, wrapped, trace) + except: + return wrapped def async_wrapper(wrapped): if is_coroutine_callable(wrapped): return coroutine_wrapper + elif is_async_generator_function(wrapped): + return async_generator_wrapper elif is_generator_function(wrapped): if is_asyncio_coroutine(wrapped): return awaitable_generator_wrapper diff --git a/newrelic/common/coroutine.py b/newrelic/common/coroutine.py index cf4c91f85..33a4922f5 100644 --- a/newrelic/common/coroutine.py +++ b/newrelic/common/coroutine.py @@ -43,3 +43,11 @@ def _iscoroutinefunction_tornado(fn): def is_coroutine_callable(wrapped): return is_coroutine_function(wrapped) or is_coroutine_function(getattr(wrapped, "__call__", None)) + + +if hasattr(inspect, 'isasyncgenfunction'): + def is_async_generator_function(wrapped): + return inspect.isasyncgenfunction(wrapped) +else: + def is_async_generator_function(wrapped): + return False diff --git a/newrelic/common/package_version_utils.py b/newrelic/common/package_version_utils.py index f3d334e2a..3152342b4 100644 --- a/newrelic/common/package_version_utils.py +++ b/newrelic/common/package_version_utils.py @@ -70,6 +70,23 @@ def int_or_str(value): def _get_package_version(name): module = sys.modules.get(name, None) version = None + + # importlib was introduced into the standard library starting in Python3.8. + if "importlib" in sys.modules and hasattr(sys.modules["importlib"], "metadata"): + try: + # In Python3.10+ packages_distribution can be checked for as well + if hasattr(sys.modules["importlib"].metadata, "packages_distributions"): # pylint: disable=E1101 + distributions = sys.modules["importlib"].metadata.packages_distributions() # pylint: disable=E1101 + distribution_name = distributions.get(name, name) + else: + distribution_name = name + + version = sys.modules["importlib"].metadata.version(distribution_name) # pylint: disable=E1101 + if version not in NULL_VERSIONS: + return version + except Exception: + pass + for attr in VERSION_ATTRS: try: version = getattr(module, attr, None) @@ -84,15 +101,6 @@ def _get_package_version(name): except Exception: pass - # importlib was introduced into the standard library starting in Python3.8. - if "importlib" in sys.modules and hasattr(sys.modules["importlib"], "metadata"): - try: - version = sys.modules["importlib"].metadata.version(name) # pylint: disable=E1101 - if version not in NULL_VERSIONS: - return version - except Exception: - pass - if "pkg_resources" in sys.modules: try: version = sys.modules["pkg_resources"].get_distribution(name).version diff --git a/newrelic/config.py b/newrelic/config.py index f1304d7b3..6816c43b5 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2285,6 +2285,87 @@ def _process_module_builtin_defaults(): "instrument_graphql_validate", ) + _process_module_definition( + "google.cloud.firestore_v1.base_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_base_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.bulk_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_bulk_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_transaction", + ) + _process_module_definition( "ariadne.asgi", "newrelic.hooks.framework_ariadne", @@ -2389,6 +2470,11 @@ def _process_module_builtin_defaults(): "newrelic.hooks.logger_loguru", "instrument_loguru_logger", ) + _process_module_definition( + "structlog._base", + "newrelic.hooks.logger_structlog", + "instrument_structlog__base", + ) _process_module_definition( "paste.httpserver", @@ -2690,20 +2776,6 @@ def _process_module_builtin_defaults(): "aioredis.connection", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_connection" ) - # Redis v4.2+ - _process_module_definition( - "redis.asyncio.client", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" - ) - - # Redis v4.2+ - _process_module_definition( - "redis.asyncio.commands", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" - ) - - _process_module_definition( - "redis.asyncio.connection", "newrelic.hooks.datastore_aioredis", "instrument_aioredis_connection" - ) - # v7 and below _process_module_definition( "elasticsearch.client", @@ -2860,6 +2932,21 @@ def _process_module_builtin_defaults(): "instrument_pymongo_collection", ) + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.client", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" + ) + + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.commands", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_client" + ) + + # Redis v4.2+ + _process_module_definition( + "redis.asyncio.connection", "newrelic.hooks.datastore_redis", "instrument_asyncio_redis_connection" + ) + _process_module_definition( "redis.connection", "newrelic.hooks.datastore_redis", @@ -2867,6 +2954,10 @@ def _process_module_builtin_defaults(): ) _process_module_definition("redis.client", "newrelic.hooks.datastore_redis", "instrument_redis_client") + _process_module_definition( + "redis.commands.cluster", "newrelic.hooks.datastore_redis", "instrument_redis_commands_cluster" + ) + _process_module_definition( "redis.commands.core", "newrelic.hooks.datastore_redis", "instrument_redis_commands_core" ) diff --git a/newrelic/core/rules_engine.py b/newrelic/core/rules_engine.py index fccc5e5e1..62ecce3fe 100644 --- a/newrelic/core/rules_engine.py +++ b/newrelic/core/rules_engine.py @@ -22,6 +22,27 @@ class NormalizationRule(_NormalizationRule): + def __new__( + cls, + match_expression="", + replacement="", + ignore=False, + eval_order=0, + terminate_chain=False, + each_segment=False, + replace_all=False, + ): + return _NormalizationRule.__new__( + cls, + match_expression=match_expression, + replacement=replacement, + ignore=ignore, + eval_order=eval_order, + terminate_chain=terminate_chain, + each_segment=each_segment, + replace_all=replace_all, + ) + def __init__(self, *args, **kwargs): self.match_expression_re = re.compile(self.match_expression, re.IGNORECASE) diff --git a/newrelic/core/stats_engine.py b/newrelic/core/stats_engine.py index 615f2b11a..ebebe7dbe 100644 --- a/newrelic/core/stats_engine.py +++ b/newrelic/core/stats_engine.py @@ -188,7 +188,6 @@ def merge_dimensional_metric(self, value): class CountStats(TimeStats): - def merge_stats(self, other): self[0] += other[0] @@ -241,6 +240,7 @@ def reset_metric_stats(self): """ self.__stats_table = {} + class DimensionalMetrics(object): """Nested dictionary table for collecting a set of metrics broken down by tags.""" @@ -294,7 +294,7 @@ def record_dimensional_metric(self, name, value, tags=None): return (name, tags) def metrics(self): - """Returns an iterator over the set of value metrics. + """Returns an iterator over the set of value metrics. The items returned are a dictionary of tags for each metric value. Metric values are each a tuple consisting of the metric name and accumulated stats for the metric. @@ -326,7 +326,7 @@ def __getitem__(self, key): def __str__(self): return str(self.__stats_table) - + def __repr__(self): return "%s(%s)" % (__class__.__name__, repr(self.__stats_table)) @@ -1284,7 +1284,11 @@ def metric_data(self, normalizer=None): if normalizer is not None: for key, value in six.iteritems(self.__stats_table): - key = (normalizer(key[0])[0], key[1]) + normalized_name, ignored = normalizer(key[0]) + if ignored: + continue + + key = (normalized_name, key[1]) stats = normalized_stats.get(key) if stats is None: normalized_stats[key] = copy.copy(value) diff --git a/newrelic/hooks/component_graphqlserver.py b/newrelic/hooks/component_graphqlserver.py index 29004c11f..ebc62a34d 100644 --- a/newrelic/hooks/component_graphqlserver.py +++ b/newrelic/hooks/component_graphqlserver.py @@ -1,19 +1,18 @@ -from newrelic.api.asgi_application import wrap_asgi_application from newrelic.api.error_trace import ErrorTrace from newrelic.api.graphql_trace import GraphQLOperationTrace from newrelic.api.transaction import current_transaction -from newrelic.api.transaction_name import TransactionNameWrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, + GRAPHQL_VERSION, + ignore_graphql_duplicate_exception, ) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception -def framework_details(): - import graphql_server - return ("GraphQLServer", getattr(graphql_server, "__version__", None)) +GRAPHQL_SERVER_VERSION = get_package_version("graphql-server") +graphql_server_major_version = int(GRAPHQL_SERVER_VERSION.split(".")[0]) + def bind_query(schema, params, *args, **kwargs): return getattr(params, "query", None) @@ -30,9 +29,8 @@ def wrap_get_response(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="GraphQLServer", version=GRAPHQL_SERVER_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -45,5 +43,8 @@ def wrap_get_response(wrapped, instance, args, kwargs): with ErrorTrace(ignore=ignore_graphql_duplicate_exception): return wrapped(*args, **kwargs) + def instrument_graphqlserver(module): - wrap_function_wrapper(module, "get_response", wrap_get_response) + if graphql_server_major_version <= 2: + return + wrap_function_wrapper(module, "get_response", wrap_get_response) diff --git a/newrelic/hooks/database_asyncpg.py b/newrelic/hooks/database_asyncpg.py index 0d03e9139..d6ca62ef3 100644 --- a/newrelic/hooks/database_asyncpg.py +++ b/newrelic/hooks/database_asyncpg.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from newrelic.api.database_trace import ( - DatabaseTrace, - enable_datastore_instance_feature, - register_database_client, -) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.datastore_trace import DatastoreTrace from newrelic.common.object_wrapper import ObjectProxy, wrap_function_wrapper @@ -43,7 +39,6 @@ def instance_info(cls, args, kwargs): quoting_style="single+dollar", instance_info=PostgresApi.instance_info, ) -enable_datastore_instance_feature(PostgresApi) class ProtocolProxy(ObjectProxy): @@ -94,9 +89,7 @@ async def query(self, query, *args, **kwargs): async def prepare(self, stmt_name, query, *args, **kwargs): with DatabaseTrace( - "PREPARE {stmt_name} FROM '{query}'".format( - stmt_name=stmt_name, query=query - ), + "PREPARE {stmt_name} FROM '{query}'".format(stmt_name=stmt_name, query=query), dbapi2_module=PostgresApi, connect_params=getattr(self, "_nr_connect_params", None), source=self.__wrapped__.prepare, @@ -131,9 +124,7 @@ def proxy_protocol(wrapped, instance, args, kwargs): def wrap_connect(wrapped, instance, args, kwargs): host = port = database_name = None if "addr" in kwargs: - host, port, database_name = PostgresApi._instance_info( - kwargs["addr"], None, kwargs.get("params") - ) + host, port, database_name = PostgresApi._instance_info(kwargs["addr"], None, kwargs.get("params")) with DatastoreTrace( PostgresApi._nr_database_product, diff --git a/newrelic/hooks/database_mysqldb.py b/newrelic/hooks/database_mysqldb.py index 31dd6bc19..c36d91d40 100644 --- a/newrelic/hooks/database_mysqldb.py +++ b/newrelic/hooks/database_mysqldb.py @@ -14,54 +14,69 @@ import os -from newrelic.api.database_trace import (enable_datastore_instance_feature, - DatabaseTrace, register_database_client) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.function_trace import FunctionTrace from newrelic.api.transaction import current_transaction from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_object +from newrelic.hooks.database_dbapi2 import ConnectionFactory as DBAPI2ConnectionFactory +from newrelic.hooks.database_dbapi2 import ConnectionWrapper as DBAPI2ConnectionWrapper -from newrelic.hooks.database_dbapi2 import (ConnectionWrapper as - DBAPI2ConnectionWrapper, ConnectionFactory as DBAPI2ConnectionFactory) class ConnectionWrapper(DBAPI2ConnectionWrapper): - def __enter__(self): transaction = current_transaction() name = callable_name(self.__wrapped__.__enter__) with FunctionTrace(name, source=self.__wrapped__.__enter__): - cursor = self.__wrapped__.__enter__() + cursor = self.__wrapped__.__enter__() # The __enter__() method of original connection object returns # a new cursor instance for use with 'as' assignment. We need # to wrap that in a cursor wrapper otherwise we will not track # any queries done via it. - return self.__cursor_wrapper__(cursor, self._nr_dbapi2_module, - self._nr_connect_params, None) + return self.__cursor_wrapper__(cursor, self._nr_dbapi2_module, self._nr_connect_params, None) def __exit__(self, exc, value, tb): transaction = current_transaction() name = callable_name(self.__wrapped__.__exit__) with FunctionTrace(name, source=self.__wrapped__.__exit__): if exc is None: - with DatabaseTrace('COMMIT', self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "COMMIT", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) else: - with DatabaseTrace('ROLLBACK', self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "ROLLBACK", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) + class ConnectionFactory(DBAPI2ConnectionFactory): __connection_wrapper__ = ConnectionWrapper + def instance_info(args, kwargs): - def _bind_params(host=None, user=None, passwd=None, db=None, port=None, - unix_socket=None, conv=None, connect_timeout=None, compress=None, - named_pipe=None, init_command=None, read_default_file=None, - read_default_group=None, *args, **kwargs): - return (host, port, db, unix_socket, - read_default_file, read_default_group) + def _bind_params( + host=None, + user=None, + passwd=None, + db=None, + port=None, + unix_socket=None, + conv=None, + connect_timeout=None, + compress=None, + named_pipe=None, + init_command=None, + read_default_file=None, + read_default_group=None, + *args, + **kwargs + ): + return (host, port, db, unix_socket, read_default_file, read_default_group) params = _bind_params(*args, **kwargs) host, port, db, unix_socket, read_default_file, read_default_group = params @@ -69,38 +84,38 @@ def _bind_params(host=None, user=None, passwd=None, db=None, port=None, port_path_or_id = None if read_default_file or read_default_group: - host = host or 'default' - port_path_or_id = 'unknown' + host = host or "default" + port_path_or_id = "unknown" elif not host: - host = 'localhost' + host = "localhost" - if host == 'localhost': + if host == "localhost": # precedence: explicit -> cnf (if used) -> env -> 'default' - port_path_or_id = (unix_socket or - port_path_or_id or - os.getenv('MYSQL_UNIX_PORT', 'default')) + port_path_or_id = unix_socket or port_path_or_id or os.getenv("MYSQL_UNIX_PORT", "default") elif explicit_host: # only reach here if host is explicitly passed in port = port and str(port) # precedence: explicit -> cnf (if used) -> env -> '3306' - port_path_or_id = (port or - port_path_or_id or - os.getenv('MYSQL_TCP_PORT', '3306')) + port_path_or_id = port or port_path_or_id or os.getenv("MYSQL_TCP_PORT", "3306") # There is no default database if omitted from the connect params # In this case, we should report unknown - db = db or 'unknown' + db = db or "unknown" return (host, port_path_or_id, db) -def instrument_mysqldb(module): - register_database_client(module, database_product='MySQL', - quoting_style='single+double', explain_query='explain', - explain_stmts=('select',), instance_info=instance_info) - enable_datastore_instance_feature(module) +def instrument_mysqldb(module): + register_database_client( + module, + database_product="MySQL", + quoting_style="single+double", + explain_query="explain", + explain_stmts=("select",), + instance_info=instance_info, + ) - wrap_object(module, 'connect', ConnectionFactory, (module,)) + wrap_object(module, "connect", ConnectionFactory, (module,)) # The connect() function is actually aliased with Connect() and # Connection, the later actually being the Connection type object. @@ -108,5 +123,5 @@ def instrument_mysqldb(module): # interferes with direct type usage. If people are using the # Connection object directly, they should really be using connect(). - if hasattr(module, 'Connect'): - wrap_object(module, 'Connect', ConnectionFactory, (module,)) + if hasattr(module, "Connect"): + wrap_object(module, "Connect", ConnectionFactory, (module,)) diff --git a/newrelic/hooks/database_psycopg2.py b/newrelic/hooks/database_psycopg2.py index 970909a33..bbed13184 100644 --- a/newrelic/hooks/database_psycopg2.py +++ b/newrelic/hooks/database_psycopg2.py @@ -15,17 +15,19 @@ import inspect import os -from newrelic.api.database_trace import (enable_datastore_instance_feature, - register_database_client, DatabaseTrace) +from newrelic.api.database_trace import DatabaseTrace, register_database_client from newrelic.api.function_trace import FunctionTrace from newrelic.api.transaction import current_transaction from newrelic.common.object_names import callable_name -from newrelic.common.object_wrapper import (wrap_object, ObjectProxy, - wrap_function_wrapper) - -from newrelic.hooks.database_dbapi2 import (ConnectionWrapper as - DBAPI2ConnectionWrapper, ConnectionFactory as DBAPI2ConnectionFactory, - CursorWrapper as DBAPI2CursorWrapper, DEFAULT) +from newrelic.common.object_wrapper import ( + ObjectProxy, + wrap_function_wrapper, + wrap_object, +) +from newrelic.hooks.database_dbapi2 import DEFAULT +from newrelic.hooks.database_dbapi2 import ConnectionFactory as DBAPI2ConnectionFactory +from newrelic.hooks.database_dbapi2 import ConnectionWrapper as DBAPI2ConnectionWrapper +from newrelic.hooks.database_dbapi2 import CursorWrapper as DBAPI2CursorWrapper try: from urllib import unquote @@ -43,33 +45,27 @@ # used. If the default connection and cursor are used without any unknown # arguments, we can safely drop all cursor parameters to generate explain # plans. Explain plans do not work with named cursors. -def _bind_connect( - dsn=None, connection_factory=None, cursor_factory=None, - *args, **kwargs): +def _bind_connect(dsn=None, connection_factory=None, cursor_factory=None, *args, **kwargs): return bool(connection_factory or cursor_factory) -def _bind_cursor( - name=None, cursor_factory=None, scrollable=None, - withhold=False, *args, **kwargs): +def _bind_cursor(name=None, cursor_factory=None, scrollable=None, withhold=False, *args, **kwargs): return bool(cursor_factory or args or kwargs) class CursorWrapper(DBAPI2CursorWrapper): - def execute(self, sql, parameters=DEFAULT, *args, **kwargs): - if hasattr(sql, 'as_string'): + if hasattr(sql, "as_string"): sql = sql.as_string(self) - return super(CursorWrapper, self).execute(sql, parameters, *args, - **kwargs) + return super(CursorWrapper, self).execute(sql, parameters, *args, **kwargs) def __enter__(self): self.__wrapped__.__enter__() return self def executemany(self, sql, seq_of_parameters): - if hasattr(sql, 'as_string'): + if hasattr(sql, "as_string"): sql = sql.as_string(self) return super(CursorWrapper, self).executemany(sql, seq_of_parameters) @@ -83,7 +79,7 @@ def __enter__(self): transaction = current_transaction() name = callable_name(self.__wrapped__.__enter__) with FunctionTrace(name, source=self.__wrapped__.__enter__): - self.__wrapped__.__enter__() + self.__wrapped__.__enter__() # Must return a reference to self as otherwise will be # returning the inner connection object. If 'as' is used @@ -98,19 +94,20 @@ def __exit__(self, exc, value, tb): name = callable_name(self.__wrapped__.__exit__) with FunctionTrace(name, source=self.__wrapped__.__exit__): if exc is None: - with DatabaseTrace('COMMIT', - self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "COMMIT", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) else: - with DatabaseTrace('ROLLBACK', - self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__): + with DatabaseTrace( + "ROLLBACK", self._nr_dbapi2_module, self._nr_connect_params, source=self.__wrapped__.__exit__ + ): return self.__wrapped__.__exit__(exc, value, tb) # This connection wrapper does not save cursor parameters for explain plans. It # is only used for the default connection class. class ConnectionWrapper(ConnectionSaveParamsWrapper): - def cursor(self, *args, **kwargs): # If any unknown cursor params are detected or a cursor factory is # used, store params for explain plans later. @@ -119,9 +116,9 @@ def cursor(self, *args, **kwargs): else: cursor_params = None - return self.__cursor_wrapper__(self.__wrapped__.cursor( - *args, **kwargs), self._nr_dbapi2_module, - self._nr_connect_params, cursor_params) + return self.__cursor_wrapper__( + self.__wrapped__.cursor(*args, **kwargs), self._nr_dbapi2_module, self._nr_connect_params, cursor_params + ) class ConnectionFactory(DBAPI2ConnectionFactory): @@ -144,15 +141,13 @@ def instance_info(args, kwargs): def _parse_connect_params(args, kwargs): - def _bind_params(dsn=None, *args, **kwargs): return dsn dsn = _bind_params(*args, **kwargs) try: - if dsn and (dsn.startswith('postgres://') or - dsn.startswith('postgresql://')): + if dsn and (dsn.startswith("postgres://") or dsn.startswith("postgresql://")): # Parse dsn as URI # @@ -166,53 +161,52 @@ def _bind_params(dsn=None, *args, **kwargs): # ipv6 brackets [] are contained in the URI hostname # and should be removed - host = host and host.strip('[]') + host = host and host.strip("[]") port = parsed_uri.port db_name = parsed_uri.path - db_name = db_name and db_name.lstrip('/') + db_name = db_name and db_name.lstrip("/") db_name = db_name or None - query = parsed_uri.query or '' + query = parsed_uri.query or "" qp = dict(parse_qsl(query)) # Query parameters override hierarchical values in URI. - host = qp.get('host') or host or None - hostaddr = qp.get('hostaddr') - port = qp.get('port') or port - db_name = qp.get('dbname') or db_name + host = qp.get("host") or host or None + hostaddr = qp.get("hostaddr") + port = qp.get("port") or port + db_name = qp.get("dbname") or db_name elif dsn: # Parse dsn as a key-value connection string - kv = dict([pair.split('=', 2) for pair in dsn.split()]) - host = kv.get('host') - hostaddr = kv.get('hostaddr') - port = kv.get('port') - db_name = kv.get('dbname') + kv = dict([pair.split("=", 2) for pair in dsn.split()]) + host = kv.get("host") + hostaddr = kv.get("hostaddr") + port = kv.get("port") + db_name = kv.get("dbname") else: # No dsn, so get the instance info from keyword arguments. - host = kwargs.get('host') - hostaddr = kwargs.get('hostaddr') - port = kwargs.get('port') - db_name = kwargs.get('database') + host = kwargs.get("host") + hostaddr = kwargs.get("hostaddr") + port = kwargs.get("port") + db_name = kwargs.get("database") # Ensure non-None values are strings. - (host, hostaddr, port, db_name) = [str(s) if s is not None else s - for s in (host, hostaddr, port, db_name)] + (host, hostaddr, port, db_name) = [str(s) if s is not None else s for s in (host, hostaddr, port, db_name)] except Exception: - host = 'unknown' - hostaddr = 'unknown' - port = 'unknown' - db_name = 'unknown' + host = "unknown" + hostaddr = "unknown" + port = "unknown" + db_name = "unknown" return (host, hostaddr, port, db_name) @@ -221,37 +215,39 @@ def _add_defaults(parsed_host, parsed_hostaddr, parsed_port, parsed_database): # ENV variables set the default values - parsed_host = parsed_host or os.environ.get('PGHOST') - parsed_hostaddr = parsed_hostaddr or os.environ.get('PGHOSTADDR') - parsed_port = parsed_port or os.environ.get('PGPORT') - database = parsed_database or os.environ.get('PGDATABASE') or 'default' + parsed_host = parsed_host or os.environ.get("PGHOST") + parsed_hostaddr = parsed_hostaddr or os.environ.get("PGHOSTADDR") + parsed_port = parsed_port or os.environ.get("PGPORT") + database = parsed_database or os.environ.get("PGDATABASE") or "default" # If hostaddr is present, we use that, since host is used for auth only. parsed_host = parsed_hostaddr or parsed_host if parsed_host is None: - host = 'localhost' - port = 'default' - elif parsed_host.startswith('/'): - host = 'localhost' - port = '%s/.s.PGSQL.%s' % (parsed_host, parsed_port or '5432') + host = "localhost" + port = "default" + elif parsed_host.startswith("/"): + host = "localhost" + port = "%s/.s.PGSQL.%s" % (parsed_host, parsed_port or "5432") else: host = parsed_host - port = parsed_port or '5432' + port = parsed_port or "5432" return (host, port, database) def instrument_psycopg2(module): - register_database_client(module, database_product='Postgres', - quoting_style='single+dollar', explain_query='explain', - explain_stmts=('select', 'insert', 'update', 'delete'), - instance_info=instance_info) - - enable_datastore_instance_feature(module) + register_database_client( + module, + database_product="Postgres", + quoting_style="single+dollar", + explain_query="explain", + explain_stmts=("select", "insert", "update", "delete"), + instance_info=instance_info, + ) - wrap_object(module, 'connect', ConnectionFactory, (module,)) + wrap_object(module, "connect", ConnectionFactory, (module,)) def wrapper_psycopg2_register_type(wrapped, instance, args, kwargs): @@ -277,7 +273,7 @@ def _bind_params(context, *args, **kwargs): # Unwrap the context for string conversion since psycopg2 uses duck typing # and a TypeError will be raised if a wrapper is used. - if hasattr(context, '__wrapped__'): + if hasattr(context, "__wrapped__"): context = context.__wrapped__ return wrapped(context, *_args, **_kwargs) @@ -289,36 +285,31 @@ def _bind_params(context, *args, **kwargs): # In doing that we need to make sure it has not already been monkey # patched by checking to see if it is already an ObjectProxy. def instrument_psycopg2__psycopg2(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2_extensions(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2__json(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2__range(module): - if hasattr(module, 'register_type'): + if hasattr(module, "register_type"): if not isinstance(module.register_type, ObjectProxy): - wrap_function_wrapper(module, 'register_type', - wrapper_psycopg2_register_type) + wrap_function_wrapper(module, "register_type", wrapper_psycopg2_register_type) def instrument_psycopg2_sql(module): - if (hasattr(module, 'Composable') and - hasattr(module.Composable, 'as_string')): + if hasattr(module, "Composable") and hasattr(module.Composable, "as_string"): for name, cls in inspect.getmembers(module): if not inspect.isclass(cls): continue @@ -326,5 +317,4 @@ def instrument_psycopg2_sql(module): if not issubclass(cls, module.Composable): continue - wrap_function_wrapper(module, name + '.as_string', - wrapper_psycopg2_as_string) + wrap_function_wrapper(module, name + ".as_string", wrapper_psycopg2_as_string) diff --git a/newrelic/hooks/datastore_aioredis.py b/newrelic/hooks/datastore_aioredis.py index df79acd75..03c0f0900 100644 --- a/newrelic/hooks/datastore_aioredis.py +++ b/newrelic/hooks/datastore_aioredis.py @@ -139,7 +139,6 @@ async def wrap_Connection_send_command(wrapped, instance, args, kwargs): ): return await wrapped(*args, **kwargs) - # This wrapper is for versions of aioredis that are outside # New Relic's supportability window but will still work. New # Relic does not provide testing/support for this. In order to diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py new file mode 100644 index 000000000..6d3196a7c --- /dev/null +++ b/newrelic/hooks/datastore_firestore.py @@ -0,0 +1,473 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.api.datastore_trace import wrap_datastore_trace +from newrelic.api.function_trace import wrap_function_trace +from newrelic.common.async_wrapper import generator_wrapper, async_generator_wrapper + + +def _conn_str_to_host(getter): + """Safely transform a getter that can retrieve a connection string into the resulting host.""" + + def closure(obj, *args, **kwargs): + try: + return getter(obj, *args, **kwargs).split(":")[0] + except Exception: + return None + + return closure + + +def _conn_str_to_port(getter): + """Safely transform a getter that can retrieve a connection string into the resulting port.""" + + def closure(obj, *args, **kwargs): + try: + return getter(obj, *args, **kwargs).split(":")[1] + except Exception: + return None + + return closure + + +# Default Target ID and Instance Info +_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None) +_get_client_database_string = lambda obj, *args, **kwargs: getattr( + getattr(obj, "_client", None), "_database_string", None +) +_get_client_target = lambda obj, *args, **kwargs: obj._client._target +_get_client_target_host = _conn_str_to_host(_get_client_target) +_get_client_target_port = _conn_str_to_port(_get_client_target) + +# Client Instance Info +_get_database_string = lambda obj, *args, **kwargs: getattr(obj, "_database_string", None) +_get_target = lambda obj, *args, **kwargs: obj._target +_get_target_host = _conn_str_to_host(_get_target) +_get_target_port = _conn_str_to_port(_get_target) + +# Query Target ID +_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None) + +# AggregationQuery Target ID +_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None) + + +def instrument_google_cloud_firestore_v1_base_client(module): + rollup = ("Datastore/all", "Datastore/Firestore/all") + wrap_function_trace( + module, "BaseClient.__init__", name="%s:BaseClient.__init__" % module.__name__, terminal=True, rollup=rollup + ) + + +def instrument_google_cloud_firestore_v1_client(module): + if hasattr(module, "Client"): + class_ = module.Client + for method in ("collections", "get_all"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Client.%s" % method, + operation=method, + product="Firestore", + target=None, + host=_get_target_host, + port_path_or_id=_get_target_port, + database_name=_get_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_client(module): + if hasattr(module, "AsyncClient"): + class_ = module.AsyncClient + for method in ("collections", "get_all"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncClient.%s" % method, + operation=method, + product="Firestore", + target=None, + host=_get_target_host, + port_path_or_id=_get_target_port, + database_name=_get_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_collection(module): + if hasattr(module, "CollectionReference"): + class_ = module.CollectionReference + for method in ("add", "get"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream", "list_documents"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_collection(module): + if hasattr(module, "AsyncCollectionReference"): + class_ = module.AsyncCollectionReference + for method in ("add", "get"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionReference.%s" % method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + operation=method, + ) + + for method in ("stream", "list_documents"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_document(module): + if hasattr(module, "DocumentReference"): + class_ = module.DocumentReference + for method in ("create", "delete", "get", "set", "update"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "DocumentReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("collections",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "DocumentReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_document(module): + if hasattr(module, "AsyncDocumentReference"): + class_ = module.AsyncDocumentReference + for method in ("create", "delete", "get", "set", "update"): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncDocumentReference.%s" % method, + product="Firestore", + target=_get_object_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("collections",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncDocumentReference.%s" % method, + operation=method, + product="Firestore", + target=_get_object_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_query(module): + if hasattr(module, "Query"): + class_ = module.Query + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Query.%s" % method, + product="Firestore", + target=_get_parent_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "Query.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + if hasattr(module, "CollectionGroup"): + class_ = module.CollectionGroup + for method in ("get_partitions",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "CollectionGroup.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_query(module): + if hasattr(module, "AsyncQuery"): + class_ = module.AsyncQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncQuery.%s" % method, + product="Firestore", + target=_get_parent_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + if hasattr(module, "AsyncCollectionGroup"): + class_ = module.AsyncCollectionGroup + for method in ("get_partitions",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncCollectionGroup.%s" % method, + operation=method, + product="Firestore", + target=_get_parent_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_aggregation(module): + if hasattr(module, "AggregationQuery"): + class_ = module.AggregationQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AggregationQuery.%s" % method, + product="Firestore", + target=_get_collection_ref_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AggregationQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_collection_ref_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_async_aggregation(module): + if hasattr(module, "AsyncAggregationQuery"): + class_ = module.AsyncAggregationQuery + for method in ("get",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncAggregationQuery.%s" % method, + product="Firestore", + target=_get_collection_ref_id, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + for method in ("stream",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncAggregationQuery.%s" % method, + operation=method, + product="Firestore", + target=_get_collection_ref_id, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + async_wrapper=async_generator_wrapper, + ) + + +def instrument_google_cloud_firestore_v1_batch(module): + if hasattr(module, "WriteBatch"): + class_ = module.WriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "WriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_async_batch(module): + if hasattr(module, "AsyncWriteBatch"): + class_ = module.AsyncWriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "AsyncWriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_bulk_batch(module): + if hasattr(module, "BulkWriteBatch"): + class_ = module.BulkWriteBatch + for method in ("commit",): + if hasattr(class_, method): + wrap_datastore_trace( + module, + "BulkWriteBatch.%s" % method, + product="Firestore", + target=None, + operation=method, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_transaction(module): + if hasattr(module, "Transaction"): + class_ = module.Transaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, + "Transaction.%s" % method, + product="Firestore", + target=None, + operation=operation, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) + + +def instrument_google_cloud_firestore_v1_async_transaction(module): + if hasattr(module, "AsyncTransaction"): + class_ = module.AsyncTransaction + for method in ("_commit", "_rollback"): + if hasattr(class_, method): + operation = method[1:] # Trim leading underscore + wrap_datastore_trace( + module, + "AsyncTransaction.%s" % method, + product="Firestore", + target=None, + operation=operation, + host=_get_client_target_host, + port_path_or_id=_get_client_target_port, + database_name=_get_client_database_string, + ) diff --git a/newrelic/hooks/datastore_redis.py b/newrelic/hooks/datastore_redis.py index 0f1c522b7..6ba192002 100644 --- a/newrelic/hooks/datastore_redis.py +++ b/newrelic/hooks/datastore_redis.py @@ -15,6 +15,7 @@ import re from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace from newrelic.api.transaction import current_transaction from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper @@ -544,6 +545,59 @@ def _nr_wrapper_asyncio_Redis_method_(wrapped, instance, args, kwargs): wrap_function_wrapper(module, name, _nr_wrapper_asyncio_Redis_method_) +async def wrap_async_Connection_send_command(wrapped, instance, args, kwargs): + transaction = current_transaction() + if not transaction: + return await wrapped(*args, **kwargs) + + host, port_path_or_id, db = (None, None, None) + + try: + dt = transaction.settings.datastore_tracer + if dt.instance_reporting.enabled or dt.database_name_reporting.enabled: + conn_kwargs = _conn_attrs_to_dict(instance) + host, port_path_or_id, db = _instance_info(conn_kwargs) + except Exception: + pass + + # Older Redis clients would when sending multi part commands pass + # them in as separate arguments to send_command(). Need to therefore + # detect those and grab the next argument from the set of arguments. + + operation = args[0].strip().lower() + + # If it's not a multi part command, there's no need to trace it, so + # we can return early. + + if ( + operation.split()[0] not in _redis_multipart_commands + ): # Set the datastore info on the DatastoreTrace containing this function call. + trace = current_trace() + + # Find DatastoreTrace no matter how many other traces are inbetween + while trace is not None and not isinstance(trace, DatastoreTrace): + trace = getattr(trace, "parent", None) + + if trace is not None: + trace.host = host + trace.port_path_or_id = port_path_or_id + trace.database_name = db + + return await wrapped(*args, **kwargs) + + # Convert multi args to single arg string + + if operation in _redis_multipart_commands and len(args) > 1: + operation = "%s %s" % (operation, args[1].strip().lower()) + + operation = _redis_operation_re.sub("_", operation) + + with DatastoreTrace( + product="Redis", target=None, operation=operation, host=host, port_path_or_id=port_path_or_id, database_name=db + ): + return await wrapped(*args, **kwargs) + + def _nr_Connection_send_command_wrapper_(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -658,4 +712,12 @@ def _instrument_redis_commands_module(module, class_name): def instrument_redis_connection(module): - wrap_function_wrapper(module, "Connection.send_command", _nr_Connection_send_command_wrapper_) + if hasattr(module, "Connection"): + if hasattr(module.Connection, "send_command"): + wrap_function_wrapper(module, "Connection.send_command", _nr_Connection_send_command_wrapper_) + + +def instrument_asyncio_redis_connection(module): + if hasattr(module, "Connection"): + if hasattr(module.Connection, "send_command"): + wrap_function_wrapper(module, "Connection.send_command", wrap_async_Connection_send_command) diff --git a/newrelic/hooks/framework_ariadne.py b/newrelic/hooks/framework_ariadne.py index 498c662c4..4927abe0b 100644 --- a/newrelic/hooks/framework_ariadne.py +++ b/newrelic/hooks/framework_ariadne.py @@ -21,17 +21,12 @@ from newrelic.api.wsgi_application import wrap_wsgi_application from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - import ariadne - - return ("Ariadne", getattr(ariadne, "__version__", None)) +ARIADNE_VERSION = get_package_version("ariadne") +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) def bind_graphql(schema, data, *args, **kwargs): @@ -49,9 +44,8 @@ def wrap_graphql_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -83,9 +77,8 @@ async def wrap_graphql(wrapped, instance, args, kwargs): result = await result return result - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) # No version info available on ariadne - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Ariadne", version=ARIADNE_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) query = data["query"] if hasattr(query, "body"): @@ -104,6 +97,9 @@ async def wrap_graphql(wrapped, instance, args, kwargs): def instrument_ariadne_execute(module): + # v0.9.0 is the version where ariadne started using graphql-core v3 + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "graphql"): wrap_function_wrapper(module, "graphql", wrap_graphql) @@ -112,10 +108,14 @@ def instrument_ariadne_execute(module): def instrument_ariadne_asgi(module): + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION)) def instrument_ariadne_wsgi(module): + if ariadne_version_tuple < (0, 9): + return if hasattr(module, "GraphQL"): - wrap_wsgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_wsgi_application(module, "GraphQL.__call__", framework=("Ariadne", ARIADNE_VERSION)) diff --git a/newrelic/hooks/framework_django.py b/newrelic/hooks/framework_django.py index 005f28279..3d9f448cc 100644 --- a/newrelic/hooks/framework_django.py +++ b/newrelic/hooks/framework_django.py @@ -12,48 +12,60 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging import sys import threading -import logging -import functools - -from newrelic.packages import six from newrelic.api.application import register_application from newrelic.api.background_task import BackgroundTaskWrapper from newrelic.api.error_trace import wrap_error_trace -from newrelic.api.function_trace import (FunctionTrace, wrap_function_trace, - FunctionTraceWrapper) +from newrelic.api.function_trace import ( + FunctionTrace, + FunctionTraceWrapper, + wrap_function_trace, +) from newrelic.api.html_insertion import insert_html_snippet -from newrelic.api.transaction import current_transaction from newrelic.api.time_trace import notice_error +from newrelic.api.transaction import current_transaction from newrelic.api.transaction_name import wrap_transaction_name from newrelic.api.wsgi_application import WSGIApplicationWrapper - -from newrelic.common.object_wrapper import (FunctionWrapper, wrap_in_function, - wrap_post_function, wrap_function_wrapper, function_wrapper) +from newrelic.common.coroutine import is_asyncio_coroutine, is_coroutine_function from newrelic.common.object_names import callable_name +from newrelic.common.object_wrapper import ( + FunctionWrapper, + function_wrapper, + wrap_function_wrapper, + wrap_in_function, + wrap_post_function, +) from newrelic.config import extra_settings from newrelic.core.config import global_settings -from newrelic.common.coroutine import is_coroutine_function, is_asyncio_coroutine +from newrelic.packages import six if six.PY3: from newrelic.hooks.framework_django_py3 import ( - _nr_wrapper_BaseHandler_get_response_async_, _nr_wrap_converted_middleware_async_, + _nr_wrapper_BaseHandler_get_response_async_, ) _logger = logging.getLogger(__name__) _boolean_states = { - '1': True, 'yes': True, 'true': True, 'on': True, - '0': False, 'no': False, 'false': False, 'off': False + "1": True, + "yes": True, + "true": True, + "on": True, + "0": False, + "no": False, + "false": False, + "off": False, } def _setting_boolean(value): if value.lower() not in _boolean_states: - raise ValueError('Not a boolean: %s' % value) + raise ValueError("Not a boolean: %s" % value) return _boolean_states[value.lower()] @@ -62,21 +74,20 @@ def _setting_set(value): _settings_types = { - 'browser_monitoring.auto_instrument': _setting_boolean, - 'instrumentation.templates.inclusion_tag': _setting_set, - 'instrumentation.background_task.startup_timeout': float, - 'instrumentation.scripts.django_admin': _setting_set, + "browser_monitoring.auto_instrument": _setting_boolean, + "instrumentation.templates.inclusion_tag": _setting_set, + "instrumentation.background_task.startup_timeout": float, + "instrumentation.scripts.django_admin": _setting_set, } _settings_defaults = { - 'browser_monitoring.auto_instrument': True, - 'instrumentation.templates.inclusion_tag': set(), - 'instrumentation.background_task.startup_timeout': 10.0, - 'instrumentation.scripts.django_admin': set(), + "browser_monitoring.auto_instrument": True, + "instrumentation.templates.inclusion_tag": set(), + "instrumentation.background_task.startup_timeout": 10.0, + "instrumentation.scripts.django_admin": set(), } -django_settings = extra_settings('import-hook:django', - types=_settings_types, defaults=_settings_defaults) +django_settings = extra_settings("import-hook:django", types=_settings_types, defaults=_settings_defaults) def should_add_browser_timing(response, transaction): @@ -92,7 +103,7 @@ def should_add_browser_timing(response, transaction): # do RUM insertion, need to move to a WSGI middleware and # deal with how to update the content length. - if hasattr(response, 'streaming_content'): + if hasattr(response, "streaming_content"): return False # Need to be running within a valid web transaction. @@ -121,21 +132,21 @@ def should_add_browser_timing(response, transaction): # a user may want to also perform insertion for # 'application/xhtml+xml'. - ctype = response.get('Content-Type', '').lower().split(';')[0] + ctype = response.get("Content-Type", "").lower().split(";")[0] if ctype not in transaction.settings.browser_monitoring.content_type: return False # Don't risk it if content encoding already set. - if response.has_header('Content-Encoding'): + if response.has_header("Content-Encoding"): return False # Don't risk it if content is actually within an attachment. - cdisposition = response.get('Content-Disposition', '').lower() + cdisposition = response.get("Content-Disposition", "").lower() - if cdisposition.split(';')[0].strip().lower() == 'attachment': + if cdisposition.split(";")[0].strip().lower() == "attachment": return False return True @@ -144,6 +155,7 @@ def should_add_browser_timing(response, transaction): # Response middleware for automatically inserting RUM header and # footer into HTML response returned by application + def browser_timing_insertion(response, transaction): # No point continuing if header is empty. This can occur if @@ -175,14 +187,15 @@ def html_to_be_inserted(): if result is not None: if transaction.settings.debug.log_autorum_middleware: - _logger.debug('RUM insertion from Django middleware ' - 'triggered. Bytes added was %r.', - len(result) - len(response.content)) + _logger.debug( + "RUM insertion from Django middleware triggered. Bytes added was %r.", + len(result) - len(response.content), + ) response.content = result - if response.get('Content-Length', None): - response['Content-Length'] = str(len(response.content)) + if response.get("Content-Length", None): + response["Content-Length"] = str(len(response.content)) return response @@ -192,18 +205,19 @@ def html_to_be_inserted(): # 'newrelic' will be automatically inserted into set of tag # libraries when performing step to instrument the middleware. + def newrelic_browser_timing_header(): from django.utils.safestring import mark_safe transaction = current_transaction() - return transaction and mark_safe(transaction.browser_timing_header()) or '' + return transaction and mark_safe(transaction.browser_timing_header()) or "" # nosec def newrelic_browser_timing_footer(): from django.utils.safestring import mark_safe transaction = current_transaction() - return transaction and mark_safe(transaction.browser_timing_footer()) or '' + return transaction and mark_safe(transaction.browser_timing_footer()) or "" # nosec # Addition of instrumentation for middleware. Can only do this @@ -256,9 +270,14 @@ def wrapper(wrapped, instance, args, kwargs): yield wrapper(wrapped) -def wrap_view_middleware(middleware): +# Because this is not being used in any version of Django that is +# within New Relic's support window, no tests will be added +# for this. However, value exists to keeping backwards compatible +# functionality, so instead of removing this instrumentation, this +# will be excluded from the coverage analysis. +def wrap_view_middleware(middleware): # pragma: no cover - # XXX This is no longer being used. The changes to strip the + # This is no longer being used. The changes to strip the # wrapper from the view handler when passed into the function # urlresolvers.reverse() solves most of the problems. To back # that up, the object wrapper now proxies various special @@ -293,7 +312,7 @@ def wrapper(wrapped, instance, args, kwargs): def _wrapped(request, view_func, view_args, view_kwargs): # This strips the view handler wrapper before call. - if hasattr(view_func, '_nr_last_object'): + if hasattr(view_func, "_nr_last_object"): view_func = view_func._nr_last_object return wrapped(request, view_func, view_args, view_kwargs) @@ -370,37 +389,28 @@ def insert_and_wrap_middleware(handler, *args, **kwargs): # priority than that for view handler so view handler # name always takes precedence. - if hasattr(handler, '_request_middleware'): - handler._request_middleware = list( - wrap_leading_middleware( - handler._request_middleware)) + if hasattr(handler, "_request_middleware"): + handler._request_middleware = list(wrap_leading_middleware(handler._request_middleware)) - if hasattr(handler, '_view_middleware'): - handler._view_middleware = list( - wrap_leading_middleware( - handler._view_middleware)) + if hasattr(handler, "_view_middleware"): + handler._view_middleware = list(wrap_leading_middleware(handler._view_middleware)) - if hasattr(handler, '_template_response_middleware'): + if hasattr(handler, "_template_response_middleware"): handler._template_response_middleware = list( - wrap_trailing_middleware( - handler._template_response_middleware)) + wrap_trailing_middleware(handler._template_response_middleware) + ) - if hasattr(handler, '_response_middleware'): - handler._response_middleware = list( - wrap_trailing_middleware( - handler._response_middleware)) + if hasattr(handler, "_response_middleware"): + handler._response_middleware = list(wrap_trailing_middleware(handler._response_middleware)) - if hasattr(handler, '_exception_middleware'): - handler._exception_middleware = list( - wrap_trailing_middleware( - handler._exception_middleware)) + if hasattr(handler, "_exception_middleware"): + handler._exception_middleware = list(wrap_trailing_middleware(handler._exception_middleware)) finally: lock.release() -def _nr_wrapper_GZipMiddleware_process_response_(wrapped, instance, args, - kwargs): +def _nr_wrapper_GZipMiddleware_process_response_(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -433,36 +443,33 @@ def _nr_wrapper_BaseHandler_get_response_(wrapped, instance, args, kwargs): request = _bind_get_response(*args, **kwargs) - if hasattr(request, '_nr_exc_info'): + if hasattr(request, "_nr_exc_info"): notice_error(error=request._nr_exc_info, status_code=response.status_code) - delattr(request, '_nr_exc_info') + delattr(request, "_nr_exc_info") return response # Post import hooks for modules. + def instrument_django_core_handlers_base(module): # Attach a post function to load_middleware() method of # BaseHandler to trigger insertion of browser timing # middleware and wrapping of middleware for timing etc. - wrap_post_function(module, 'BaseHandler.load_middleware', - insert_and_wrap_middleware) + wrap_post_function(module, "BaseHandler.load_middleware", insert_and_wrap_middleware) - if six.PY3 and hasattr(module.BaseHandler, 'get_response_async'): - wrap_function_wrapper(module, 'BaseHandler.get_response_async', - _nr_wrapper_BaseHandler_get_response_async_) + if six.PY3 and hasattr(module.BaseHandler, "get_response_async"): + wrap_function_wrapper(module, "BaseHandler.get_response_async", _nr_wrapper_BaseHandler_get_response_async_) - wrap_function_wrapper(module, 'BaseHandler.get_response', - _nr_wrapper_BaseHandler_get_response_) + wrap_function_wrapper(module, "BaseHandler.get_response", _nr_wrapper_BaseHandler_get_response_) def instrument_django_gzip_middleware(module): - wrap_function_wrapper(module, 'GZipMiddleware.process_response', - _nr_wrapper_GZipMiddleware_process_response_) + wrap_function_wrapper(module, "GZipMiddleware.process_response", _nr_wrapper_GZipMiddleware_process_response_) def wrap_handle_uncaught_exception(middleware): @@ -506,10 +513,9 @@ def instrument_django_core_handlers_wsgi(module): import django - framework = ('Django', django.get_version()) + framework = ("Django", django.get_version()) - module.WSGIHandler.__call__ = WSGIApplicationWrapper( - module.WSGIHandler.__call__, framework=framework) + module.WSGIHandler.__call__ = WSGIApplicationWrapper(module.WSGIHandler.__call__, framework=framework) # Wrap handle_uncaught_exception() of WSGIHandler so that # can capture exception details of any exception which @@ -519,10 +525,10 @@ def instrument_django_core_handlers_wsgi(module): # exception, so last chance to do this as exception will not # propagate up to the WSGI application. - if hasattr(module.WSGIHandler, 'handle_uncaught_exception'): - module.WSGIHandler.handle_uncaught_exception = ( - wrap_handle_uncaught_exception( - module.WSGIHandler.handle_uncaught_exception)) + if hasattr(module.WSGIHandler, "handle_uncaught_exception"): + module.WSGIHandler.handle_uncaught_exception = wrap_handle_uncaught_exception( + module.WSGIHandler.handle_uncaught_exception + ) def wrap_view_handler(wrapped, priority=3): @@ -532,7 +538,7 @@ def wrap_view_handler(wrapped, priority=3): # called recursively. We flag that view handler was wrapped # using the '_nr_django_view_handler' attribute. - if hasattr(wrapped, '_nr_django_view_handler'): + if hasattr(wrapped, "_nr_django_view_handler"): return wrapped if hasattr(wrapped, "view_class"): @@ -584,7 +590,7 @@ def wrapper(wrapped, instance, args, kwargs): if transaction is None: return wrapped(*args, **kwargs) - if hasattr(transaction, '_nr_django_url_resolver'): + if hasattr(transaction, "_nr_django_url_resolver"): return wrapped(*args, **kwargs) # Tag the transaction so we know when we are in the top @@ -602,8 +608,7 @@ def _wrapped(path): if type(result) is tuple: callback, callback_args, callback_kwargs = result - result = (wrap_view_handler(callback, priority=5), - callback_args, callback_kwargs) + result = (wrap_view_handler(callback, priority=5), callback_args, callback_kwargs) else: result.func = wrap_view_handler(result.func, priority=5) @@ -636,8 +641,7 @@ def wrapper(wrapped, instance, args, kwargs): return wrap_view_handler(result, priority=priority) else: callback, param_dict = result - return (wrap_view_handler(callback, priority=priority), - param_dict) + return (wrap_view_handler(callback, priority=priority), param_dict) return FunctionWrapper(wrapped, wrapper) @@ -653,9 +657,10 @@ def wrap_url_reverse(wrapped): def wrapper(wrapped, instance, args, kwargs): def execute(viewname, *args, **kwargs): - if hasattr(viewname, '_nr_last_object'): + if hasattr(viewname, "_nr_last_object"): viewname = viewname._nr_last_object return wrapped(viewname, *args, **kwargs) + return execute(*args, **kwargs) return FunctionWrapper(wrapped, wrapper) @@ -672,20 +677,19 @@ def instrument_django_core_urlresolvers(module): # lost. We thus intercept it here so can capture that # traceback which is otherwise lost. - wrap_error_trace(module, 'get_callable') + wrap_error_trace(module, "get_callable") # Wrap methods which resolves a request to a view handler. # This can be called against a resolver initialised against # a custom URL conf associated with a specific request, or a # resolver which uses the default URL conf. - if hasattr(module, 'RegexURLResolver'): + if hasattr(module, "RegexURLResolver"): urlresolver = module.RegexURLResolver else: urlresolver = module.URLResolver - urlresolver.resolve = wrap_url_resolver( - urlresolver.resolve) + urlresolver.resolve = wrap_url_resolver(urlresolver.resolve) # Wrap methods which resolve error handlers. For 403 and 404 # we give these higher naming priority over any prior @@ -695,26 +699,22 @@ def instrument_django_core_urlresolvers(module): # handler in place so error details identify the correct # transaction. - if hasattr(urlresolver, 'resolve403'): - urlresolver.resolve403 = wrap_url_resolver_nnn( - urlresolver.resolve403, priority=3) + if hasattr(urlresolver, "resolve403"): + urlresolver.resolve403 = wrap_url_resolver_nnn(urlresolver.resolve403, priority=3) - if hasattr(urlresolver, 'resolve404'): - urlresolver.resolve404 = wrap_url_resolver_nnn( - urlresolver.resolve404, priority=3) + if hasattr(urlresolver, "resolve404"): + urlresolver.resolve404 = wrap_url_resolver_nnn(urlresolver.resolve404, priority=3) - if hasattr(urlresolver, 'resolve500'): - urlresolver.resolve500 = wrap_url_resolver_nnn( - urlresolver.resolve500, priority=1) + if hasattr(urlresolver, "resolve500"): + urlresolver.resolve500 = wrap_url_resolver_nnn(urlresolver.resolve500, priority=1) - if hasattr(urlresolver, 'resolve_error_handler'): - urlresolver.resolve_error_handler = wrap_url_resolver_nnn( - urlresolver.resolve_error_handler, priority=1) + if hasattr(urlresolver, "resolve_error_handler"): + urlresolver.resolve_error_handler = wrap_url_resolver_nnn(urlresolver.resolve_error_handler, priority=1) # Wrap function for performing reverse URL lookup to strip any # instrumentation wrapper when view handler is passed in. - if hasattr(module, 'reverse'): + if hasattr(module, "reverse"): module.reverse = wrap_url_reverse(module.reverse) @@ -723,7 +723,7 @@ def instrument_django_urls_base(module): # Wrap function for performing reverse URL lookup to strip any # instrumentation wrapper when view handler is passed in. - if hasattr(module, 'reverse'): + if hasattr(module, "reverse"): module.reverse = wrap_url_reverse(module.reverse) @@ -742,17 +742,15 @@ def instrument_django_template(module): def template_name(template, *args): return template.name - if hasattr(module.Template, '_render'): - wrap_function_trace(module, 'Template._render', - name=template_name, group='Template/Render') + if hasattr(module.Template, "_render"): + wrap_function_trace(module, "Template._render", name=template_name, group="Template/Render") else: - wrap_function_trace(module, 'Template.render', - name=template_name, group='Template/Render') + wrap_function_trace(module, "Template.render", name=template_name, group="Template/Render") # Django 1.8 no longer has module.libraries. As automatic way is not # preferred we can just skip this now. - if not hasattr(module, 'libraries'): + if not hasattr(module, "libraries"): return # Register template tags used for manual insertion of RUM @@ -766,12 +764,12 @@ def template_name(template, *args): library.simple_tag(newrelic_browser_timing_header) library.simple_tag(newrelic_browser_timing_footer) - module.libraries['django.templatetags.newrelic'] = library + module.libraries["django.templatetags.newrelic"] = library def wrap_template_block(wrapped): def wrapper(wrapped, instance, args, kwargs): - return FunctionTraceWrapper(wrapped, name=instance.name, group='Template/Block')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=instance.name, group="Template/Block")(*args, **kwargs) return FunctionWrapper(wrapped, wrapper) @@ -812,11 +810,15 @@ def instrument_django_core_servers_basehttp(module): # instrumentation of the wsgiref module or some other means. def wrap_wsgi_application_entry_point(server, application, **kwargs): - return ((server, WSGIApplicationWrapper(application, - framework='Django'),), kwargs) + return ( + ( + server, + WSGIApplicationWrapper(application, framework="Django"), + ), + kwargs, + ) - if (not hasattr(module, 'simple_server') and - hasattr(module.ServerHandler, 'run')): + if not hasattr(module, "simple_server") and hasattr(module.ServerHandler, "run"): # Patch the server to make it work properly. @@ -833,11 +835,10 @@ def run(self, application): def close(self): if self.result is not None: try: - self.request_handler.log_request( - self.status.split(' ', 1)[0], self.bytes_sent) + self.request_handler.log_request(self.status.split(" ", 1)[0], self.bytes_sent) finally: try: - if hasattr(self.result, 'close'): + if hasattr(self.result, "close"): self.result.close() finally: self.result = None @@ -855,17 +856,16 @@ def close(self): # Now wrap it with our instrumentation. - wrap_in_function(module, 'ServerHandler.run', - wrap_wsgi_application_entry_point) + wrap_in_function(module, "ServerHandler.run", wrap_wsgi_application_entry_point) def instrument_django_contrib_staticfiles_views(module): - if not hasattr(module.serve, '_nr_django_view_handler'): + if not hasattr(module.serve, "_nr_django_view_handler"): module.serve = wrap_view_handler(module.serve, priority=3) def instrument_django_contrib_staticfiles_handlers(module): - wrap_transaction_name(module, 'StaticFilesHandler.serve') + wrap_transaction_name(module, "StaticFilesHandler.serve") def instrument_django_views_debug(module): @@ -878,10 +878,8 @@ def instrument_django_views_debug(module): # from a middleware or view handler in place so error # details identify the correct transaction. - module.technical_404_response = wrap_view_handler( - module.technical_404_response, priority=3) - module.technical_500_response = wrap_view_handler( - module.technical_500_response, priority=1) + module.technical_404_response = wrap_view_handler(module.technical_404_response, priority=3) + module.technical_500_response = wrap_view_handler(module.technical_500_response, priority=1) def resolve_view_handler(view, request): @@ -890,8 +888,7 @@ def resolve_view_handler(view, request): # duplicate the lookup mechanism. if request.method.lower() in view.http_method_names: - handler = getattr(view, request.method.lower(), - view.http_method_not_allowed) + handler = getattr(view, request.method.lower(), view.http_method_not_allowed) else: handler = view.http_method_not_allowed @@ -936,7 +933,7 @@ def _args(request, *args, **kwargs): priority = 4 - if transaction.group == 'Function': + if transaction.group == "Function": if transaction.name == callable_name(view): priority = 5 @@ -953,22 +950,22 @@ def instrument_django_views_generic_base(module): def instrument_django_http_multipartparser(module): - wrap_function_trace(module, 'MultiPartParser.parse') + wrap_function_trace(module, "MultiPartParser.parse") def instrument_django_core_mail(module): - wrap_function_trace(module, 'mail_admins') - wrap_function_trace(module, 'mail_managers') - wrap_function_trace(module, 'send_mail') + wrap_function_trace(module, "mail_admins") + wrap_function_trace(module, "mail_managers") + wrap_function_trace(module, "send_mail") def instrument_django_core_mail_message(module): - wrap_function_trace(module, 'EmailMessage.send') + wrap_function_trace(module, "EmailMessage.send") def _nr_wrapper_BaseCommand___init___(wrapped, instance, args, kwargs): instance.handle = FunctionTraceWrapper(instance.handle) - if hasattr(instance, 'handle_noargs'): + if hasattr(instance, "handle_noargs"): instance.handle_noargs = FunctionTraceWrapper(instance.handle_noargs) return wrapped(*args, **kwargs) @@ -982,29 +979,25 @@ def _args(argv, *args, **kwargs): subcommand = _argv[1] commands = django_settings.instrumentation.scripts.django_admin - startup_timeout = \ - django_settings.instrumentation.background_task.startup_timeout + startup_timeout = django_settings.instrumentation.background_task.startup_timeout if subcommand not in commands: return wrapped(*args, **kwargs) application = register_application(timeout=startup_timeout) - return BackgroundTaskWrapper(wrapped, application, subcommand, 'Django')(*args, **kwargs) + return BackgroundTaskWrapper(wrapped, application, subcommand, "Django")(*args, **kwargs) def instrument_django_core_management_base(module): - wrap_function_wrapper(module, 'BaseCommand.__init__', - _nr_wrapper_BaseCommand___init___) - wrap_function_wrapper(module, 'BaseCommand.run_from_argv', - _nr_wrapper_BaseCommand_run_from_argv_) + wrap_function_wrapper(module, "BaseCommand.__init__", _nr_wrapper_BaseCommand___init___) + wrap_function_wrapper(module, "BaseCommand.run_from_argv", _nr_wrapper_BaseCommand_run_from_argv_) @function_wrapper -def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, - args, kwargs): +def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, args, kwargs): - name = hasattr(wrapped, '__name__') and wrapped.__name__ + name = hasattr(wrapped, "__name__") and wrapped.__name__ if name is None: return wrapped(*args, **kwargs) @@ -1013,16 +1006,14 @@ def _nr_wrapper_django_inclusion_tag_wrapper_(wrapped, instance, tags = django_settings.instrumentation.templates.inclusion_tag - if '*' not in tags and name not in tags and qualname not in tags: + if "*" not in tags and name not in tags and qualname not in tags: return wrapped(*args, **kwargs) - return FunctionTraceWrapper(wrapped, name=name, group='Template/Tag')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=name, group="Template/Tag")(*args, **kwargs) @function_wrapper -def _nr_wrapper_django_inclusion_tag_decorator_(wrapped, instance, - args, kwargs): - +def _nr_wrapper_django_inclusion_tag_decorator_(wrapped, instance, args, kwargs): def _bind_params(func, *args, **kwargs): return func, args, kwargs @@ -1033,63 +1024,56 @@ def _bind_params(func, *args, **kwargs): return wrapped(func, *_args, **_kwargs) -def _nr_wrapper_django_template_base_Library_inclusion_tag_(wrapped, - instance, args, kwargs): +def _nr_wrapper_django_template_base_Library_inclusion_tag_(wrapped, instance, args, kwargs): - return _nr_wrapper_django_inclusion_tag_decorator_( - wrapped(*args, **kwargs)) + return _nr_wrapper_django_inclusion_tag_decorator_(wrapped(*args, **kwargs)) @function_wrapper -def _nr_wrapper_django_template_base_InclusionNode_render_(wrapped, - instance, args, kwargs): +def _nr_wrapper_django_template_base_InclusionNode_render_(wrapped, instance, args, kwargs): if wrapped.__self__ is None: return wrapped(*args, **kwargs) - file_name = getattr(wrapped.__self__, '_nr_file_name', None) + file_name = getattr(wrapped.__self__, "_nr_file_name", None) if file_name is None: return wrapped(*args, **kwargs) name = wrapped.__self__._nr_file_name - return FunctionTraceWrapper(wrapped, name=name, group='Template/Include')(*args, **kwargs) + return FunctionTraceWrapper(wrapped, name=name, group="Template/Include")(*args, **kwargs) -def _nr_wrapper_django_template_base_generic_tag_compiler_(wrapped, instance, - args, kwargs): +def _nr_wrapper_django_template_base_generic_tag_compiler_(wrapped, instance, args, kwargs): if wrapped.__code__.co_argcount > 6: # Django > 1.3. - def _bind_params(parser, token, params, varargs, varkw, defaults, - name, takes_context, node_class, *args, **kwargs): + def _bind_params( + parser, token, params, varargs, varkw, defaults, name, takes_context, node_class, *args, **kwargs + ): return node_class + else: # Django <= 1.3. - def _bind_params(params, defaults, name, node_class, parser, token, - *args, **kwargs): + def _bind_params(params, defaults, name, node_class, parser, token, *args, **kwargs): return node_class node_class = _bind_params(*args, **kwargs) - if node_class.__name__ == 'InclusionNode': + if node_class.__name__ == "InclusionNode": result = wrapped(*args, **kwargs) - result.render = ( - _nr_wrapper_django_template_base_InclusionNode_render_( - result.render)) + result.render = _nr_wrapper_django_template_base_InclusionNode_render_(result.render) return result return wrapped(*args, **kwargs) -def _nr_wrapper_django_template_base_Library_tag_(wrapped, instance, - args, kwargs): - +def _nr_wrapper_django_template_base_Library_tag_(wrapped, instance, args, kwargs): def _bind_params(name=None, compile_function=None, *args, **kwargs): return compile_function @@ -1105,14 +1089,16 @@ def _get_node_class(compile_function): # Django >= 1.4 uses functools.partial if isinstance(compile_function, functools.partial): - node_class = compile_function.keywords.get('node_class') + node_class = compile_function.keywords.get("node_class") # Django < 1.4 uses their home-grown "curry" function, # not functools.partial. - if (hasattr(compile_function, 'func_closure') and - hasattr(compile_function, '__name__') and - compile_function.__name__ == '_curried'): + if ( + hasattr(compile_function, "func_closure") + and hasattr(compile_function, "__name__") + and compile_function.__name__ == "_curried" + ): # compile_function here is generic_tag_compiler(), which has been # curried. To get node_class, we first get the function obj, args, @@ -1121,19 +1107,20 @@ def _get_node_class(compile_function): # is not consistent from platform to platform, so we need to map # them to the variables in compile_function.__code__.co_freevars. - cells = dict(zip(compile_function.__code__.co_freevars, - (c.cell_contents for c in compile_function.func_closure))) + cells = dict( + zip(compile_function.__code__.co_freevars, (c.cell_contents for c in compile_function.func_closure)) + ) # node_class is the 4th arg passed to generic_tag_compiler() - if 'args' in cells and len(cells['args']) > 3: - node_class = cells['args'][3] + if "args" in cells and len(cells["args"]) > 3: + node_class = cells["args"][3] return node_class node_class = _get_node_class(compile_function) - if node_class is None or node_class.__name__ != 'InclusionNode': + if node_class is None or node_class.__name__ != "InclusionNode": return wrapped(*args, **kwargs) # Climb stack to find the file_name of the include template. @@ -1146,9 +1133,8 @@ def _get_node_class(compile_function): for i in range(1, stack_levels + 1): frame = sys._getframe(i) - if ('generic_tag_compiler' in frame.f_code.co_names and - 'file_name' in frame.f_code.co_freevars): - file_name = frame.f_locals.get('file_name') + if "generic_tag_compiler" in frame.f_code.co_names and "file_name" in frame.f_code.co_freevars: + file_name = frame.f_locals.get("file_name") if file_name is None: return wrapped(*args, **kwargs) @@ -1167,22 +1153,22 @@ def instrument_django_template_base(module): settings = global_settings() - if 'django.instrumentation.inclusion-tags.r1' in settings.feature_flag: + if "django.instrumentation.inclusion-tags.r1" in settings.feature_flag: - if hasattr(module, 'generic_tag_compiler'): - wrap_function_wrapper(module, 'generic_tag_compiler', - _nr_wrapper_django_template_base_generic_tag_compiler_) + if hasattr(module, "generic_tag_compiler"): + wrap_function_wrapper( + module, "generic_tag_compiler", _nr_wrapper_django_template_base_generic_tag_compiler_ + ) - if hasattr(module, 'Library'): - wrap_function_wrapper(module, 'Library.tag', - _nr_wrapper_django_template_base_Library_tag_) + if hasattr(module, "Library"): + wrap_function_wrapper(module, "Library.tag", _nr_wrapper_django_template_base_Library_tag_) - wrap_function_wrapper(module, 'Library.inclusion_tag', - _nr_wrapper_django_template_base_Library_inclusion_tag_) + wrap_function_wrapper( + module, "Library.inclusion_tag", _nr_wrapper_django_template_base_Library_inclusion_tag_ + ) def _nr_wrap_converted_middleware_(middleware, name): - @function_wrapper def _wrapper(wrapped, instance, args, kwargs): transaction = current_transaction() @@ -1197,9 +1183,7 @@ def _wrapper(wrapped, instance, args, kwargs): return _wrapper(middleware) -def _nr_wrapper_convert_exception_to_response_(wrapped, instance, args, - kwargs): - +def _nr_wrapper_convert_exception_to_response_(wrapped, instance, args, kwargs): def _bind_params(original_middleware, *args, **kwargs): return original_middleware @@ -1214,21 +1198,19 @@ def _bind_params(original_middleware, *args, **kwargs): def instrument_django_core_handlers_exception(module): - if hasattr(module, 'convert_exception_to_response'): - wrap_function_wrapper(module, 'convert_exception_to_response', - _nr_wrapper_convert_exception_to_response_) + if hasattr(module, "convert_exception_to_response"): + wrap_function_wrapper(module, "convert_exception_to_response", _nr_wrapper_convert_exception_to_response_) - if hasattr(module, 'handle_uncaught_exception'): - module.handle_uncaught_exception = ( - wrap_handle_uncaught_exception( - module.handle_uncaught_exception)) + if hasattr(module, "handle_uncaught_exception"): + module.handle_uncaught_exception = wrap_handle_uncaught_exception(module.handle_uncaught_exception) def instrument_django_core_handlers_asgi(module): import django - framework = ('Django', django.get_version()) + framework = ("Django", django.get_version()) - if hasattr(module, 'ASGIHandler'): + if hasattr(module, "ASGIHandler"): from newrelic.api.asgi_application import wrap_asgi_application - wrap_asgi_application(module, 'ASGIHandler.__call__', framework=framework) + + wrap_asgi_application(module, "ASGIHandler.__call__", framework=framework) diff --git a/newrelic/hooks/framework_flask.py b/newrelic/hooks/framework_flask.py index c0540a60d..6ef45e6af 100644 --- a/newrelic/hooks/framework_flask.py +++ b/newrelic/hooks/framework_flask.py @@ -166,7 +166,7 @@ def _nr_wrapper_error_handler_(wrapped, instance, args, kwargs): return FunctionTraceWrapper(wrapped, name=name)(*args, **kwargs) -def _nr_wrapper_Flask__register_error_handler_(wrapped, instance, args, kwargs): +def _nr_wrapper_Flask__register_error_handler_(wrapped, instance, args, kwargs): # pragma: no cover def _bind_params(key, code_or_exception, f): return key, code_or_exception, f @@ -189,7 +189,6 @@ def _bind_params(code_or_exception, f): def _nr_wrapper_Flask_try_trigger_before_first_request_functions_(wrapped, instance, args, kwargs): - transaction = current_transaction() if transaction is None: @@ -355,7 +354,6 @@ def _nr_wrapper_Blueprint_endpoint_(wrapped, instance, args, kwargs): @function_wrapper def _nr_wrapper_Blueprint_before_request_wrapped_(wrapped, instance, args, kwargs): - transaction = current_transaction() if transaction is None: diff --git a/newrelic/hooks/framework_graphql.py b/newrelic/hooks/framework_graphql.py index d261b2e9f..df86e6984 100644 --- a/newrelic/hooks/framework_graphql.py +++ b/newrelic/hooks/framework_graphql.py @@ -13,7 +13,10 @@ # limitations under the License. import logging +import sys +import time from collections import deque +from inspect import isawaitable from newrelic.api.error_trace import ErrorTrace from newrelic.api.function_trace import FunctionTrace @@ -22,7 +25,14 @@ from newrelic.api.transaction import current_transaction, ignore_transaction from newrelic.common.object_names import callable_name, parse_exc_info from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement +from newrelic.hooks.framework_graphql_py3 import ( + nr_coro_execute_name_wrapper, + nr_coro_graphql_impl_wrapper, + nr_coro_resolver_error_wrapper, + nr_coro_resolver_wrapper, +) _logger = logging.getLogger(__name__) @@ -32,23 +42,8 @@ VERSION = None -def framework_version(): - """Framework version string.""" - global VERSION - if VERSION is None: - from graphql import __version__ as version - - VERSION = version - - return VERSION - - -def graphql_version(): - """Minor version tuple.""" - version = framework_version() - - # Take first two values in version to avoid ValueErrors with pre-releases (ex: 3.2.0a0) - return tuple(int(v) for v in version.split(".")[:2]) +GRAPHQL_VERSION = get_package_version("graphql-core") +major_version = int(GRAPHQL_VERSION.split(".")[0]) def ignore_graphql_duplicate_exception(exc, val, tb): @@ -98,10 +93,6 @@ def bind_operation_v3(operation, root_value): return operation -def bind_operation_v2(exe_context, operation, root_value): - return operation - - def wrap_execute_operation(wrapped, instance, args, kwargs): transaction = current_transaction() trace = current_trace() @@ -118,15 +109,9 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): try: operation = bind_operation_v3(*args, **kwargs) except TypeError: - try: - operation = bind_operation_v2(*args, **kwargs) - except TypeError: - return wrapped(*args, **kwargs) + return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - execution_context = args[0] - else: - execution_context = instance + execution_context = instance trace.operation_name = get_node_value(operation, "name") or "" @@ -145,12 +130,17 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=11) result = wrapped(*args, **kwargs) - if not execution_context.errors: - if hasattr(trace, "set_transaction_name"): + + def set_name(value=None): + if not execution_context.errors and hasattr(trace, "set_transaction_name"): # Operation trace sets transaction name trace.set_transaction_name(priority=14) + return value - return result + if isawaitable(result): + return nr_coro_execute_name_wrapper(wrapped, result, set_name) + else: + return set_name(result) def get_node_value(field, attr, subattr="value"): @@ -161,39 +151,25 @@ def get_node_value(field, attr, subattr="value"): def is_fragment_spread_node(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread + from graphql.language.ast import FragmentSpreadNode - return isinstance(field, FragmentSpread) + return isinstance(field, FragmentSpreadNode) def is_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import FragmentSpread, InlineFragment - except ImportError: - from graphql import FragmentSpreadNode as FragmentSpread - from graphql import InlineFragmentNode as InlineFragment - - _fragment_types = (InlineFragment, FragmentSpread) + from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode + _fragment_types = (InlineFragmentNode, FragmentSpreadNode) return isinstance(field, _fragment_types) def is_named_fragment(field): - # Resolve version specific imports - try: - from graphql.language.ast import NamedType - except ImportError: - from graphql import NamedTypeNode as NamedType + from graphql.language.ast import NamedTypeNode return ( is_fragment(field) and getattr(field, "type_condition", None) is not None - and isinstance(field.type_condition, NamedType) + and isinstance(field.type_condition, NamedTypeNode) ) @@ -321,12 +297,25 @@ def wrap_resolver(wrapped, instance, args, kwargs): if transaction is None: return wrapped(*args, **kwargs) - name = callable_name(wrapped) + base_resolver = getattr(wrapped, "_nr_base_resolver", wrapped) + + name = callable_name(base_resolver) transaction.set_transaction_name(name, "GraphQL", priority=13) + trace = FunctionTrace(name, source=base_resolver) - with FunctionTrace(name, source=wrapped): - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - return wrapped(*args, **kwargs) + with ErrorTrace(ignore=ignore_graphql_duplicate_exception): + sync_start_time = time.time() + result = wrapped(*args, **kwargs) + + if isawaitable(result): + # Grab any async resolvers and wrap with traces + return nr_coro_resolver_error_wrapper( + wrapped, name, trace, ignore_graphql_duplicate_exception, result, transaction + ) + else: + with trace: + trace.start_time = sync_start_time + return result def wrap_error_handler(wrapped, instance, args, kwargs): @@ -368,19 +357,12 @@ def bind_resolve_field_v3(parent_type, source, field_nodes, path): return parent_type, field_nodes, path -def bind_resolve_field_v2(exe_context, parent_type, source, field_asts, parent_info, field_path): - return parent_type, field_asts, field_path - - def wrap_resolve_field(wrapped, instance, args, kwargs): transaction = current_transaction() if transaction is None: return wrapped(*args, **kwargs) - if graphql_version() < (3, 0): - bind_resolve_field = bind_resolve_field_v2 - else: - bind_resolve_field = bind_resolve_field_v3 + bind_resolve_field = bind_resolve_field_v3 try: parent_type, field_asts, field_path = bind_resolve_field(*args, **kwargs) @@ -390,18 +372,34 @@ def wrap_resolve_field(wrapped, instance, args, kwargs): field_name = field_asts[0].name.value field_def = parent_type.fields.get(field_name) field_return_type = str(field_def.type) if field_def else "" + if isinstance(field_path, list): + field_path = field_path[0] + else: + field_path = field_path.key - with GraphQLResolverTrace(field_name) as trace: - with ErrorTrace(ignore=ignore_graphql_duplicate_exception): - trace._add_agent_attribute("graphql.field.parentType", parent_type.name) - trace._add_agent_attribute("graphql.field.returnType", field_return_type) + trace = GraphQLResolverTrace( + field_name, field_parent_type=parent_type.name, field_return_type=field_return_type, field_path=field_path + ) + start_time = time.time() - if isinstance(field_path, list): - trace._add_agent_attribute("graphql.field.path", field_path[0]) - else: - trace._add_agent_attribute("graphql.field.path", field_path.key) + try: + result = wrapped(*args, **kwargs) + except Exception: + # Synchonous resolver with exception raised + with trace: + trace.start_time = start_time + notice_error(ignore=ignore_graphql_duplicate_exception) + raise - return wrapped(*args, **kwargs) + if isawaitable(result): + # Asynchronous resolvers (returned coroutines from non-coroutine functions) + # Return a coroutine that handles wrapping in a resolver trace + return nr_coro_resolver_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) + else: + # Synchonous resolver with no exception raised + with trace: + trace.start_time = start_time + return result def bind_graphql_impl_query(schema, source, *args, **kwargs): @@ -428,11 +426,8 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): if not transaction: return wrapped(*args, **kwargs) - transaction.add_framework_info(name="GraphQL", version=framework_version()) - if graphql_version() < (3, 0): - bind_query = bind_execute_graphql_query - else: - bind_query = bind_graphql_impl_query + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) + bind_query = bind_graphql_impl_query try: schema, query = bind_query(*args, **kwargs) @@ -444,17 +439,34 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs): transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=10) - with GraphQLOperationTrace() as trace: - trace.statement = graphql_statement(query) + trace = GraphQLOperationTrace() + + trace.statement = graphql_statement(query) - # Handle Schemas created from frameworks - if hasattr(schema, "_nr_framework"): - framework = schema._nr_framework - trace.product = framework[0] - transaction.add_framework_info(name=framework[0], version=framework[1]) + # Handle Schemas created from frameworks + if hasattr(schema, "_nr_framework"): + framework = schema._nr_framework + trace.product = framework[0] + transaction.add_framework_info(name=framework[0], version=framework[1]) + # Trace must be manually started and stopped to ensure it exists prior to and during the entire duration of the query. + # Otherwise subsequent instrumentation will not be able to find an operation trace and will have issues. + trace.__enter__() + try: with ErrorTrace(ignore=ignore_graphql_duplicate_exception): result = wrapped(*args, **kwargs) + except Exception as e: + # Execution finished synchronously, exit immediately. + trace.__exit__(*sys.exc_info()) + raise + else: + if isawaitable(result): + # Asynchronous implementations + # Return a coroutine that handles closing the operation trace + return nr_coro_graphql_impl_wrapper(wrapped, trace, ignore_graphql_duplicate_exception, result) + else: + # Execution finished synchronously, exit immediately. + trace.__exit__(None, None, None) return result @@ -480,11 +492,15 @@ def instrument_graphql_execute(module): def instrument_graphql_execution_utils(module): + if major_version == 2: + return if hasattr(module, "ExecutionContext"): wrap_function_wrapper(module, "ExecutionContext.__init__", wrap_executor_context_init) def instrument_graphql_execution_middleware(module): + if major_version == 2: + return if hasattr(module, "get_middleware_resolvers"): wrap_function_wrapper(module, "get_middleware_resolvers", wrap_get_middleware_resolvers) if hasattr(module, "MiddlewareManager"): @@ -492,20 +508,26 @@ def instrument_graphql_execution_middleware(module): def instrument_graphql_error_located_error(module): + if major_version == 2: + return if hasattr(module, "located_error"): wrap_function_wrapper(module, "located_error", wrap_error_handler) def instrument_graphql_validate(module): + if major_version == 2: + return wrap_function_wrapper(module, "validate", wrap_validate) def instrument_graphql(module): + if major_version == 2: + return if hasattr(module, "graphql_impl"): wrap_function_wrapper(module, "graphql_impl", wrap_graphql_impl) - if hasattr(module, "execute_graphql"): - wrap_function_wrapper(module, "execute_graphql", wrap_graphql_impl) def instrument_graphql_parser(module): + if major_version == 2: + return wrap_function_wrapper(module, "parse", wrap_parse) diff --git a/newrelic/hooks/framework_graphql_py3.py b/newrelic/hooks/framework_graphql_py3.py new file mode 100644 index 000000000..3931aa6ed --- /dev/null +++ b/newrelic/hooks/framework_graphql_py3.py @@ -0,0 +1,68 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import sys + +from newrelic.api.error_trace import ErrorTrace +from newrelic.api.function_trace import FunctionTrace + + +def nr_coro_execute_name_wrapper(wrapped, result, set_name): + @functools.wraps(wrapped) + async def _nr_coro_execute_name_wrapper(): + result_ = await result + set_name() + return result_ + + return _nr_coro_execute_name_wrapper() + + +def nr_coro_resolver_error_wrapper(wrapped, name, trace, ignore, result, transaction): + @functools.wraps(wrapped) + async def _nr_coro_resolver_error_wrapper(): + with trace: + with ErrorTrace(ignore=ignore): + try: + return await result + except Exception: + transaction.set_transaction_name(name, "GraphQL", priority=15) + raise + + return _nr_coro_resolver_error_wrapper() + + +def nr_coro_resolver_wrapper(wrapped, trace, ignore, result): + @functools.wraps(wrapped) + async def _nr_coro_resolver_wrapper(): + with trace: + with ErrorTrace(ignore=ignore): + return await result + + return _nr_coro_resolver_wrapper() + +def nr_coro_graphql_impl_wrapper(wrapped, trace, ignore, result): + @functools.wraps(wrapped) + async def _nr_coro_graphql_impl_wrapper(): + try: + with ErrorTrace(ignore=ignore): + result_ = await result + except: + trace.__exit__(*sys.exc_info()) + raise + else: + trace.__exit__(None, None, None) + return result_ + + + return _nr_coro_graphql_impl_wrapper() \ No newline at end of file diff --git a/newrelic/hooks/framework_strawberry.py b/newrelic/hooks/framework_strawberry.py index 92a0ea8b4..e6d06bb04 100644 --- a/newrelic/hooks/framework_strawberry.py +++ b/newrelic/hooks/framework_strawberry.py @@ -16,20 +16,14 @@ from newrelic.api.error_trace import ErrorTrace from newrelic.api.graphql_trace import GraphQLOperationTrace from newrelic.api.transaction import current_transaction -from newrelic.api.transaction_name import TransactionNameWrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version from newrelic.core.graphql_utils import graphql_statement -from newrelic.hooks.framework_graphql import ( - framework_version as graphql_framework_version, -) -from newrelic.hooks.framework_graphql import ignore_graphql_duplicate_exception +from newrelic.hooks.framework_graphql import GRAPHQL_VERSION, ignore_graphql_duplicate_exception - -def framework_details(): - import strawberry - - return ("Strawberry", getattr(strawberry, "__version__", None)) +STRAWBERRY_GRAPHQL_VERSION = get_package_version("strawberry-graphql") +strawberry_version_tuple = tuple(map(int, STRAWBERRY_GRAPHQL_VERSION.split("."))) def bind_execute(query, *args, **kwargs): @@ -47,9 +41,8 @@ def wrap_execute_sync(wrapped, instance, args, kwargs): except TypeError: return wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -74,9 +67,8 @@ async def wrap_execute(wrapped, instance, args, kwargs): except TypeError: return await wrapped(*args, **kwargs) - framework = framework_details() - transaction.add_framework_info(name=framework[0], version=framework[1]) - transaction.add_framework_info(name="GraphQL", version=graphql_framework_version()) + transaction.add_framework_info(name="Strawberry", version=STRAWBERRY_GRAPHQL_VERSION) + transaction.add_framework_info(name="GraphQL", version=GRAPHQL_VERSION) if hasattr(query, "body"): query = query.body @@ -98,19 +90,20 @@ def wrap_from_resolver(wrapped, instance, args, kwargs): result = wrapped(*args, **kwargs) try: - field = bind_from_resolver(*args, **kwargs) + field = bind_from_resolver(*args, **kwargs) except TypeError: pass else: if hasattr(field, "base_resolver"): if hasattr(field.base_resolver, "wrapped_func"): - resolver_name = callable_name(field.base_resolver.wrapped_func) - result = TransactionNameWrapper(result, resolver_name, "GraphQL", priority=13) + result._nr_base_resolver = field.base_resolver.wrapped_func return result def instrument_strawberry_schema(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "Schema"): if hasattr(module.Schema, "execute"): wrap_function_wrapper(module, "Schema.execute", wrap_execute) @@ -119,11 +112,15 @@ def instrument_strawberry_schema(module): def instrument_strawberry_asgi(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQL"): - wrap_asgi_application(module, "GraphQL.__call__", framework=framework_details()) + wrap_asgi_application(module, "GraphQL.__call__", framework=("Strawberry", STRAWBERRY_GRAPHQL_VERSION)) def instrument_strawberry_schema_converter(module): + if strawberry_version_tuple < (0, 23, 3): + return if hasattr(module, "GraphQLCoreConverter"): if hasattr(module.GraphQLCoreConverter, "from_resolver"): wrap_function_wrapper(module, "GraphQLCoreConverter.from_resolver", wrap_from_resolver) diff --git a/newrelic/hooks/logger_structlog.py b/newrelic/hooks/logger_structlog.py new file mode 100644 index 000000000..e652a795c --- /dev/null +++ b/newrelic/hooks/logger_structlog.py @@ -0,0 +1,86 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.common.object_wrapper import wrap_function_wrapper +from newrelic.api.transaction import current_transaction, record_log_event +from newrelic.core.config import global_settings +from newrelic.api.application import application_instance +from newrelic.hooks.logger_logging import add_nr_linking_metadata + + +def normalize_level_name(method_name): + # Look up level number for method name, using result to look up level name for that level number. + # Convert result to upper case, and default to UNKNOWN in case of errors or missing values. + try: + from structlog._log_levels import _LEVEL_TO_NAME, _NAME_TO_LEVEL + return _LEVEL_TO_NAME[_NAME_TO_LEVEL[method_name]].upper() + except Exception: + return "UNKNOWN" + + +def bind_process_event(method_name, event, event_kw): + return method_name, event, event_kw + + +def wrap__process_event(wrapped, instance, args, kwargs): + try: + method_name, event, event_kw = bind_process_event(*args, **kwargs) + except TypeError: + return wrapped(*args, **kwargs) + + original_message = event # Save original undecorated message + + transaction = current_transaction() + + if transaction: + settings = transaction.settings + else: + settings = global_settings() + + # Return early if application logging not enabled + if settings and settings.application_logging and settings.application_logging.enabled: + if settings.application_logging.local_decorating and settings.application_logging.local_decorating.enabled: + event = add_nr_linking_metadata(event) + + # Send log to processors for filtering, allowing any DropEvent exceptions that occur to prevent instrumentation from recording the log event. + result = wrapped(method_name, event, event_kw) + + level_name = normalize_level_name(method_name) + + if settings.application_logging.metrics and settings.application_logging.metrics.enabled: + if transaction: + transaction.record_custom_metric("Logging/lines", {"count": 1}) + transaction.record_custom_metric("Logging/lines/%s" % level_name, {"count": 1}) + else: + application = application_instance(activate=False) + if application and application.enabled: + application.record_custom_metric("Logging/lines", {"count": 1}) + application.record_custom_metric("Logging/lines/%s" % level_name, {"count": 1}) + + if settings.application_logging.forwarding and settings.application_logging.forwarding.enabled: + try: + record_log_event(original_message, level_name) + + except Exception: + pass + + # Return the result from wrapped after we've recorded the resulting log event. + return result + + return wrapped(*args, **kwargs) + + +def instrument_structlog__base(module): + if hasattr(module, "BoundLoggerBase") and hasattr(module.BoundLoggerBase, "_process_event"): + wrap_function_wrapper(module, "BoundLoggerBase._process_event", wrap__process_event) diff --git a/tests/agent_features/_test_async_coroutine_trace.py b/tests/agent_features/_test_async_coroutine_trace.py index 51b81f5f6..1250b8c25 100644 --- a/tests/agent_features/_test_async_coroutine_trace.py +++ b/tests/agent_features/_test_async_coroutine_trace.py @@ -28,6 +28,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -41,6 +42,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_awaitable_timing(event_loop, trace, metric): @@ -79,6 +82,8 @@ def _test(): (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) @pytest.mark.parametrize("yield_from", [True, False]) diff --git a/tests/agent_features/_test_async_generator_trace.py b/tests/agent_features/_test_async_generator_trace.py new file mode 100644 index 000000000..30b970c37 --- /dev/null +++ b/tests/agent_features/_test_async_generator_trace.py @@ -0,0 +1,548 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import sys +import time + +import pytest +from testing_support.fixtures import capture_transaction_metrics, validate_tt_parenting +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +asyncio = pytest.importorskip("asyncio") + + +@pytest.mark.parametrize( + "trace,metric", + [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), + ], +) +def test_async_generator_timing(event_loop, trace, metric): + @trace() + async def simple_gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_async_generator_timing", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_async_generator_timing") + def _test_async_generator_timing(): + async def _test(): + async for _ in simple_gen(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_timing() + + # Check that coroutines time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +class MyException(Exception): + pass + + +@validate_transaction_metrics( + "test_async_generator_error", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=["_test_async_generator_trace:MyException"]) +def test_async_generator_error(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_error") + async def _test(): + gen = agen() + await gen.asend(None) + await gen.athrow(MyException) + + with pytest.raises(MyException): + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_caught_exception", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=[]) +def test_async_generator_caught_exception(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + time.sleep(0.1) + try: + yield + except ValueError: + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_caught_exception") + def _test_async_generator_caught_exception(): + async def _test(): + gen = agen() + # kickstart the generator (the try/except logic is inside the + # generator) + await gen.asend(None) + await gen.athrow(ValueError) + + # consume the generator + async for _ in gen: + pass + + # The ValueError should not be reraised + event_loop.run_until_complete(_test()) + _test_async_generator_caught_exception() + + assert full_metrics[("Function/agen", "")].total_call_time >= 0.2 + + +@validate_transaction_metrics( + "test_async_generator_handles_terminal_nodes", + background_task=True, + scoped_metrics=[("Function/parent", 1), ("Function/agen", None)], + rollup_metrics=[("Function/parent", 1), ("Function/agen", None)], +) +def test_async_generator_handles_terminal_nodes(event_loop): + # sometimes coroutines can be called underneath terminal nodes + # In this case, the trace shouldn't actually be created and we also + # shouldn't get any errors + + @function_trace(name="agen") + async def agen(): + yield + time.sleep(0.1) + + @function_trace(name="parent", terminal=True) + async def parent(): + # parent calls child + async for _ in agen(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_handles_terminal_nodes") + def _test_async_generator_handles_terminal_nodes(): + async def _test(): + await parent() + + event_loop.run_until_complete(_test()) + _test_async_generator_handles_terminal_nodes() + + metric_key = ("Function/parent", "") + assert full_metrics[metric_key].total_exclusive_call_time >= 0.1 + + +@validate_transaction_metrics( + "test_async_generator_close_ends_trace", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_close_ends_trace(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_close_ends_trace") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # trace should be ended/recorded by close + await gen.aclose() + + # We may call gen.close as many times as we want + await gen.aclose() + + event_loop.run_until_complete(_test()) + +@validate_tt_parenting( + ( + "TransactionNode", + [ + ( + "FunctionNode", + [ + ("FunctionNode", []), + ], + ), + ], + ) +) +@validate_transaction_metrics( + "test_async_generator_parents", + background_task=True, + scoped_metrics=[("Function/child", 1), ("Function/parent", 1)], + rollup_metrics=[("Function/child", 1), ("Function/parent", 1)], +) +def test_async_generator_parents(event_loop): + @function_trace(name="child") + async def child(): + yield + time.sleep(0.1) + yield + + @function_trace(name="parent") + async def parent(): + time.sleep(0.1) + yield + async for _ in child(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_parents") + def _test_async_generator_parents(): + async def _test(): + async for _ in parent(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_parents() + + # Check that the child time is subtracted from the parent time (parenting + # relationship is correctly established) + key = ("Function/parent", "") + assert full_metrics[key].total_exclusive_call_time < 0.2 + + +@validate_transaction_metrics( + "test_asend_receives_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_asend_receives_a_value(event_loop): + _received = [] + @function_trace(name="agen") + async def agen(): + value = yield + _received.append(value) + yield value + + @background_task(name="test_asend_receives_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.asend("foobar") == "foobar" + assert _received and _received[0] == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_yields_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_yields_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + try: + yield + except MyException: + yield "foobar" + + @background_task(name="test_athrow_yields_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.athrow(MyException) == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_multiple_throws_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + + @background_task(name="test_multiple_throws_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + assert await gen.asend(None) is None + assert await gen.athrow(MyException) == "foo" + assert await gen.athrow(MyException) == "foo" + assert await gen.asend(None) == "bar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_does_not_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_does_not_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + try: + yield + except MyException: + return + + @background_task(name="test_athrow_does_not_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # async generator will raise StopAsyncIteration + with pytest.raises(StopAsyncIteration): + await gen.athrow(MyException) + + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize( + "trace", + [ + function_trace(name="simple_gen"), + external_trace(library="lib", url="http://foo.com"), + database_trace("select * from foo"), + datastore_trace("lib", "foo", "bar"), + message_trace("lib", "op", "typ", "name"), + memcache_trace("cmd"), + ], +) +def test_async_generator_functions_outside_of_transaction(event_loop, trace): + @trace + async def agen(): + for _ in range(2): + yield "foo" + + async def _test(): + assert [_ async for _ in agen()] == ["foo", "foo"] + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_catching_generator_exit_causes_runtime_error", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_catching_generator_exit_causes_runtime_error(event_loop): + @function_trace(name="agen") + async def agen(): + try: + yield + except GeneratorExit: + yield + + @background_task(name="test_catching_generator_exit_causes_runtime_error") + async def _test(): + gen = agen() + + # kickstart the coroutine (we're inside the try now) + await gen.asend(None) + + # Generators cannot catch generator exit exceptions (which are injected by + # close). This will result in a runtime error. + with pytest.raises(RuntimeError): + await gen.aclose() + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_time_excludes_creation_time", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_time_excludes_creation_time(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_time_excludes_creation_time") + def _test_async_generator_time_excludes_creation_time(): + async def _test(): + gen = agen() + time.sleep(0.1) + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_time_excludes_creation_time() + + # check that the trace does not include the time between creation and + # consumption + assert full_metrics[("Function/agen", "")].total_call_time < 0.1 + + +@validate_transaction_metrics( + "test_complete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@background_task(name="test_complete_async_generator") +def test_complete_async_generator(event_loop): + @function_trace(name="agen") + async def agen(): + for i in range(5): + yield i + + async def _test(): + gen = agen() + assert [x async for x in gen] == [x for x in range(5)] + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize("nr_transaction", [True, False]) +def test_incomplete_async_generator(event_loop, nr_transaction): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + def _test_incomplete_async_generator(): + async def _test(): + c = agen() + + async for _ in c: + break + + if nr_transaction: + _test = background_task(name="test_incomplete_async_generator")(_test) + + event_loop.run_until_complete(_test()) + + if nr_transaction: + _test_incomplete_async_generator = validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + )(_test_incomplete_async_generator) + + _test_incomplete_async_generator() + + +def test_incomplete_async_generator_transaction_exited(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + @validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + ) + def _test_incomplete_async_generator(): + c = agen() + @background_task(name="test_incomplete_async_generator") + async def _test(): + async for _ in c: + break + + event_loop.run_until_complete(_test()) + + # Remove generator after transaction completes + del c + + _test_incomplete_async_generator() diff --git a/tests/agent_features/test_async_generator_trace.py b/tests/agent_features/test_async_generator_trace.py new file mode 100644 index 000000000..208cf1588 --- /dev/null +++ b/tests/agent_features/test_async_generator_trace.py @@ -0,0 +1,19 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +# Async Generators were introduced in Python 3.6, but some APIs weren't completely stable until Python 3.7. +if sys.version_info >= (3, 7): + from _test_async_generator_trace import * # NOQA diff --git a/tests/agent_features/test_async_wrapper_detection.py b/tests/agent_features/test_async_wrapper_detection.py new file mode 100644 index 000000000..bb1fd3f1e --- /dev/null +++ b/tests/agent_features/test_async_wrapper_detection.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import functools +import time + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +from newrelic.common.async_wrapper import generator_wrapper + +from testing_support.fixtures import capture_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +trace_metric_cases = [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), +] + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_automatic_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace() + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + for _ in gen(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_manual_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace(async_wrapper=generator_wrapper) + def wrapper_func(): + """Function that returns a generator object, obscuring the automatic introspection of async_wrapper()""" + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + return gen() + + for _ in wrapper_func(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 diff --git a/tests/agent_features/test_coroutine_trace.py b/tests/agent_features/test_coroutine_trace.py index 36e365bc4..2043f1326 100644 --- a/tests/agent_features/test_coroutine_trace.py +++ b/tests/agent_features/test_coroutine_trace.py @@ -31,6 +31,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -47,6 +48,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_coroutine_timing(trace, metric): @@ -337,6 +340,37 @@ def coro(): pass +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/coro", 1)], + rollup_metrics=[("Function/coro", 1)], +) +@background_task(name="test_multiple_throws_yield_a_value") +def test_multiple_throws_yield_a_value(): + @function_trace(name="coro") + def coro(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + c = coro() + + # kickstart the coroutine + assert next(c) is None + assert c.throw(MyException) == "foo" + assert c.throw(MyException) == "foo" + assert next(c) == "bar" + + # finish consumption of the coroutine if necessary + for _ in c: + pass + + @pytest.mark.parametrize( "trace", [ diff --git a/tests/agent_features/test_datastore_trace.py b/tests/agent_features/test_datastore_trace.py new file mode 100644 index 000000000..08067e040 --- /dev/null +++ b/tests/agent_features/test_datastore_trace.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from testing_support.validators.validate_datastore_trace_inputs import ( + validate_datastore_trace_inputs, +) + +from newrelic.api.background_task import background_task +from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper + + +@validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", +) +@background_task() +def test_dt_trace_all_args(): + with DatastoreTrace( + product="Agent Features", + target="test_target", + operation="test_operation", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ): + pass + + +@validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None) +@background_task() +def test_dt_trace_empty(): + with DatastoreTrace(product=None, target=None, operation=None): + pass + + +@background_task() +def test_dt_trace_callable_args(): + def product_callable(): + return "Agent Features" + + def target_callable(): + return "test_target" + + def operation_callable(): + return "test_operation" + + def host_callable(): + return "test_host" + + def port_path_id_callable(): + return "test_port" + + def db_name_callable(): + return "test_db_name" + + @validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ) + def _test(): + pass + + wrapped_fn = DatastoreTraceWrapper( + _test, + product=product_callable, + target=target_callable, + operation=operation_callable, + host=host_callable, + port_path_or_id=port_path_id_callable, + database_name=db_name_callable, + ) + wrapped_fn() diff --git a/tests/agent_features/test_span_events.py b/tests/agent_features/test_span_events.py index b9c04a8c8..05e375ff3 100644 --- a/tests/agent_features/test_span_events.py +++ b/tests/agent_features/test_span_events.py @@ -141,7 +141,6 @@ def test_each_span_type(trace_type, args): ) @background_task(name="test_each_span_type") def _test(): - transaction = current_transaction() transaction._sampled = True @@ -307,7 +306,6 @@ def _test(): } ) def test_external_span_limits(kwarg_override, attr_override): - exact_intrinsics = { "type": "Span", "sampled": True, @@ -364,7 +362,6 @@ def _test(): } ) def test_datastore_span_limits(kwarg_override, attribute_override): - exact_intrinsics = { "type": "Span", "sampled": True, @@ -416,10 +413,6 @@ def _test(): @pytest.mark.parametrize("span_events_enabled", (False, True)) def test_collect_span_events_override(collect_span_events, span_events_enabled): spans_expected = collect_span_events and span_events_enabled - # if collect_span_events and span_events_enabled: - # spans_expected = True - # else: - # spans_expected = False span_count = 2 if spans_expected else 0 @@ -509,7 +502,6 @@ def __exit__(self, *args): ) @pytest.mark.parametrize("exclude_attributes", (True, False)) def test_span_event_user_attributes(trace_type, args, exclude_attributes): - _settings = { "distributed_tracing.enabled": True, "span_events.enabled": True, @@ -626,7 +618,6 @@ def _test(): ), ) def test_span_event_error_attributes_notice_error(trace_type, args): - _settings = { "distributed_tracing.enabled": True, "span_events.enabled": True, @@ -674,7 +665,6 @@ def _test(): ), ) def test_span_event_error_attributes_observed(trace_type, args): - error = ValueError("whoops") exact_agents = { diff --git a/tests/agent_unittests/test_agent_protocol.py b/tests/agent_unittests/test_agent_protocol.py index ba75358ab..1f0401439 100644 --- a/tests/agent_unittests/test_agent_protocol.py +++ b/tests/agent_unittests/test_agent_protocol.py @@ -565,6 +565,7 @@ def test_ca_bundle_path(monkeypatch, ca_bundle_path): # Pretend CA certificates are not available class DefaultVerifyPaths(object): cafile = None + capath = None def __init__(self, *args, **kwargs): pass diff --git a/tests/agent_unittests/test_http_client.py b/tests/agent_unittests/test_http_client.py index a5c340d6a..df409f932 100644 --- a/tests/agent_unittests/test_http_client.py +++ b/tests/agent_unittests/test_http_client.py @@ -325,7 +325,7 @@ def test_http_payload_compression(server, client_cls, method, threshold): # Verify the compressed payload length is recorded assert internal_metrics["Supportability/Python/Collector/method1/ZLIB/Bytes"][:2] == [1, payload_byte_len] assert internal_metrics["Supportability/Python/Collector/ZLIB/Bytes"][:2] == [2, payload_byte_len*2] - + assert len(internal_metrics) == 8 else: # Verify no ZLIB compression metrics were sent @@ -366,11 +366,14 @@ def test_cert_path(server): def test_default_cert_path(monkeypatch, system_certs_available): if system_certs_available: cert_file = "foo" + ca_path = "/usr/certs" else: cert_file = None + ca_path = None class DefaultVerifyPaths(object): cafile = cert_file + capath = ca_path def __init__(self, *args, **kwargs): pass diff --git a/tests/agent_unittests/test_package_version_utils.py b/tests/agent_unittests/test_package_version_utils.py index 435d74947..30c22cff1 100644 --- a/tests/agent_unittests/test_package_version_utils.py +++ b/tests/agent_unittests/test_package_version_utils.py @@ -24,11 +24,19 @@ get_package_version_tuple, ) +# Notes: +# importlib.metadata was a provisional addition to the std library in PY38 and PY39 +# while pkg_resources was deprecated. +# importlib.metadata is no longer provisional in PY310+. It added some attributes +# such as distribution_packages and removed pkg_resources. + IS_PY38_PLUS = sys.version_info[:2] >= (3, 8) +IS_PY310_PLUS = sys.version_info[:2] >= (3,10) SKIP_IF_NOT_IMPORTLIB_METADATA = pytest.mark.skipif(not IS_PY38_PLUS, reason="importlib.metadata is not supported.") SKIP_IF_IMPORTLIB_METADATA = pytest.mark.skipif( IS_PY38_PLUS, reason="importlib.metadata is preferred over pkg_resources." ) +SKIP_IF_NOT_PY310_PLUS = pytest.mark.skipif(not IS_PY310_PLUS, reason="These features were added in 3.10+") @pytest.fixture(scope="function", autouse=True) @@ -38,8 +46,10 @@ def patched_pytest_module(monkeypatch): monkeypatch.delattr(pytest, attr) yield pytest + - +# This test only works on Python 3.7 +@SKIP_IF_IMPORTLIB_METADATA @pytest.mark.parametrize( "attr,value,expected_value", ( @@ -58,6 +68,8 @@ def test_get_package_version(attr, value, expected_value): delattr(pytest, attr) +# This test only works on Python 3.7 +@SKIP_IF_IMPORTLIB_METADATA def test_skips_version_callables(): # There is no file/module here, so we monkeypatch # pytest instead for our purposes @@ -72,6 +84,8 @@ def test_skips_version_callables(): delattr(pytest, "version_tuple") +# This test only works on Python 3.7 +@SKIP_IF_IMPORTLIB_METADATA @pytest.mark.parametrize( "attr,value,expected_value", ( @@ -97,6 +111,13 @@ def test_importlib_metadata(): assert version not in NULL_VERSIONS, version +@SKIP_IF_NOT_PY310_PLUS +@validate_function_called("importlib.metadata", "packages_distributions") +def test_mapping_import_to_distribution_packages(): + version = get_package_version("pytest") + assert version not in NULL_VERSIONS, version + + @SKIP_IF_IMPORTLIB_METADATA @validate_function_called("pkg_resources", "get_distribution") def test_pkg_resources_metadata(): diff --git a/tests/component_djangorestframework/test_application.py b/tests/component_djangorestframework/test_application.py index 9ed60aa33..29861dca8 100644 --- a/tests/component_djangorestframework/test_application.py +++ b/tests/component_djangorestframework/test_application.py @@ -12,190 +12,168 @@ # See the License for the specific language governing permissions and # limitations under the License. +import django import pytest import webtest +from testing_support.fixtures import function_not_called, override_generic_settings +from testing_support.validators.validate_code_level_metrics import ( + validate_code_level_metrics, +) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) -from newrelic.packages import six from newrelic.core.config import global_settings +from newrelic.packages import six -from testing_support.fixtures import ( - override_generic_settings, - function_not_called) -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_code_level_metrics import validate_code_level_metrics -import django - -DJANGO_VERSION = tuple(map(int, django.get_version().split('.')[:2])) - +DJANGO_VERSION = tuple(map(int, django.get_version().split(".")[:2])) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def target_application(): from wsgi import application + test_application = webtest.TestApp(application) return test_application if DJANGO_VERSION >= (1, 10): - url_module_path = 'django.urls.resolvers' + url_module_path = "django.urls.resolvers" # Django 1.10 new style middleware removed individual process_* methods. # All middleware in Django 1.10+ is called through the __call__ methods on # middlwares. - process_request_method = '' - process_view_method = '' - process_response_method = '' + process_request_method = "" + process_view_method = "" + process_response_method = "" else: - url_module_path = 'django.core.urlresolvers' - process_request_method = '.process_request' - process_view_method = '.process_view' - process_response_method = '.process_response' + url_module_path = "django.core.urlresolvers" + process_request_method = ".process_request" + process_view_method = ".process_view" + process_response_method = ".process_response" if DJANGO_VERSION >= (2, 0): - url_resolver_cls = 'URLResolver' + url_resolver_cls = "URLResolver" else: - url_resolver_cls = 'RegexURLResolver' + url_resolver_cls = "RegexURLResolver" _scoped_metrics = [ - ('Function/django.core.handlers.wsgi:WSGIHandler.__call__', 1), - ('Python/WSGI/Application', 1), - ('Python/WSGI/Response', 1), - ('Python/WSGI/Finalize', 1), - (('Function/django.middleware.common:' - 'CommonMiddleware' + process_request_method), 1), - (('Function/django.contrib.sessions.middleware:' - 'SessionMiddleware' + process_request_method), 1), - (('Function/django.contrib.auth.middleware:' - 'AuthenticationMiddleware' + process_request_method), 1), - (('Function/django.contrib.messages.middleware:' - 'MessageMiddleware' + process_request_method), 1), - (('Function/%s:' % url_module_path + - '%s.resolve' % url_resolver_cls), 1), - (('Function/django.middleware.csrf:' - 'CsrfViewMiddleware' + process_view_method), 1), - (('Function/django.contrib.messages.middleware:' - 'MessageMiddleware' + process_response_method), 1), - (('Function/django.middleware.csrf:' - 'CsrfViewMiddleware' + process_response_method), 1), - (('Function/django.contrib.sessions.middleware:' - 'SessionMiddleware' + process_response_method), 1), - (('Function/django.middleware.common:' - 'CommonMiddleware' + process_response_method), 1), + ("Function/django.core.handlers.wsgi:WSGIHandler.__call__", 1), + ("Python/WSGI/Application", 1), + ("Python/WSGI/Response", 1), + ("Python/WSGI/Finalize", 1), + (("Function/django.middleware.common:CommonMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.auth.middleware:AuthenticationMiddleware%s" % process_request_method), 1), + (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_request_method), 1), + (("Function/%s:%s.resolve" % (url_module_path, url_resolver_cls)), 1), + (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_view_method), 1), + (("Function/django.contrib.messages.middleware:MessageMiddleware%s" % process_response_method), 1), + (("Function/django.middleware.csrf:CsrfViewMiddleware%s" % process_response_method), 1), + (("Function/django.contrib.sessions.middleware:SessionMiddleware%s" % process_response_method), 1), + (("Function/django.middleware.common:CommonMiddleware%s" % process_response_method), 1), ] _test_application_index_scoped_metrics = list(_scoped_metrics) -_test_application_index_scoped_metrics.append(('Function/views:index', 1)) +_test_application_index_scoped_metrics.append(("Function/views:index", 1)) if DJANGO_VERSION >= (1, 5): - _test_application_index_scoped_metrics.extend([ - ('Function/django.http.response:HttpResponse.close', 1)]) + _test_application_index_scoped_metrics.extend([("Function/django.http.response:HttpResponse.close", 1)]) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics('views:index', - scoped_metrics=_test_application_index_scoped_metrics) +@validate_transaction_metrics("views:index", scoped_metrics=_test_application_index_scoped_metrics) @validate_code_level_metrics("views", "index") def test_application_index(target_application): - response = target_application.get('') - response.mustcontain('INDEX RESPONSE') + response = target_application.get("") + response.mustcontain("INDEX RESPONSE") _test_application_view_scoped_metrics = list(_scoped_metrics) -_test_application_view_scoped_metrics.append(('Function/urls:View.get', 1)) +_test_application_view_scoped_metrics.append(("Function/urls:View.get", 1)) if DJANGO_VERSION >= (1, 5): - _test_application_view_scoped_metrics.extend([ - ('Function/rest_framework.response:Response.close', 1)]) + _test_application_view_scoped_metrics.extend([("Function/rest_framework.response:Response.close", 1)]) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics('urls:View.get', - scoped_metrics=_test_application_view_scoped_metrics) +@validate_transaction_metrics("urls:View.get", scoped_metrics=_test_application_view_scoped_metrics) @validate_code_level_metrics("urls.View", "get") def test_application_view(target_application): - response = target_application.get('/view/') + response = target_application.get("/view/") assert response.status_int == 200 - response.mustcontain('restframework view response') + response.mustcontain("restframework view response") _test_application_view_error_scoped_metrics = list(_scoped_metrics) -_test_application_view_error_scoped_metrics.append( - ('Function/urls:ViewError.get', 1)) +_test_application_view_error_scoped_metrics.append(("Function/urls:ViewError.get", 1)) -@validate_transaction_errors(errors=['urls:Error']) -@validate_transaction_metrics('urls:ViewError.get', - scoped_metrics=_test_application_view_error_scoped_metrics) +@validate_transaction_errors(errors=["urls:Error"]) +@validate_transaction_metrics("urls:ViewError.get", scoped_metrics=_test_application_view_error_scoped_metrics) @validate_code_level_metrics("urls.ViewError", "get") def test_application_view_error(target_application): - target_application.get('/view_error/', status=500) + target_application.get("/view_error/", status=500) _test_application_view_handle_error_scoped_metrics = list(_scoped_metrics) -_test_application_view_handle_error_scoped_metrics.append( - ('Function/urls:ViewHandleError.get', 1)) +_test_application_view_handle_error_scoped_metrics.append(("Function/urls:ViewHandleError.get", 1)) -@pytest.mark.parametrize('status,should_record', [(418, True), (200, False)]) -@pytest.mark.parametrize('use_global_exc_handler', [True, False]) +@pytest.mark.parametrize("status,should_record", [(418, True), (200, False)]) +@pytest.mark.parametrize("use_global_exc_handler", [True, False]) @validate_code_level_metrics("urls.ViewHandleError", "get") -def test_application_view_handle_error(status, should_record, - use_global_exc_handler, target_application): - errors = ['urls:Error'] if should_record else [] +def test_application_view_handle_error(status, should_record, use_global_exc_handler, target_application): + errors = ["urls:Error"] if should_record else [] @validate_transaction_errors(errors=errors) - @validate_transaction_metrics('urls:ViewHandleError.get', - scoped_metrics=_test_application_view_handle_error_scoped_metrics) + @validate_transaction_metrics( + "urls:ViewHandleError.get", scoped_metrics=_test_application_view_handle_error_scoped_metrics + ) def _test(): - response = target_application.get( - '/view_handle_error/%s/%s/' % (status, use_global_exc_handler), - status=status) + response = target_application.get("/view_handle_error/%s/%s/" % (status, use_global_exc_handler), status=status) if use_global_exc_handler: - response.mustcontain('exception was handled global') + response.mustcontain("exception was handled global") else: - response.mustcontain('exception was handled not global') + response.mustcontain("exception was handled not global") _test() -_test_api_view_view_name_get = 'urls:wrapped_view.get' +_test_api_view_view_name_get = "urls:wrapped_view.get" _test_api_view_scoped_metrics_get = list(_scoped_metrics) -_test_api_view_scoped_metrics_get.append( - ('Function/%s' % _test_api_view_view_name_get, 1)) +_test_api_view_scoped_metrics_get.append(("Function/%s" % _test_api_view_view_name_get, 1)) @validate_transaction_errors(errors=[]) -@validate_transaction_metrics(_test_api_view_view_name_get, - scoped_metrics=_test_api_view_scoped_metrics_get) -@validate_code_level_metrics("urls.WrappedAPIView" if six.PY3 else "urls", "wrapped_view") +@validate_transaction_metrics(_test_api_view_view_name_get, scoped_metrics=_test_api_view_scoped_metrics_get) +@validate_code_level_metrics("urls.WrappedAPIView", "wrapped_view", py2_namespace="urls") def test_api_view_get(target_application): - response = target_application.get('/api_view/') - response.mustcontain('wrapped_view response') + response = target_application.get("/api_view/") + response.mustcontain("wrapped_view response") -_test_api_view_view_name_post = 'urls:wrapped_view.http_method_not_allowed' +_test_api_view_view_name_post = "urls:wrapped_view.http_method_not_allowed" _test_api_view_scoped_metrics_post = list(_scoped_metrics) -_test_api_view_scoped_metrics_post.append( - ('Function/%s' % _test_api_view_view_name_post, 1)) +_test_api_view_scoped_metrics_post.append(("Function/%s" % _test_api_view_view_name_post, 1)) -@validate_transaction_errors( - errors=['rest_framework.exceptions:MethodNotAllowed']) -@validate_transaction_metrics(_test_api_view_view_name_post, - scoped_metrics=_test_api_view_scoped_metrics_post) +@validate_transaction_errors(errors=["rest_framework.exceptions:MethodNotAllowed"]) +@validate_transaction_metrics(_test_api_view_view_name_post, scoped_metrics=_test_api_view_scoped_metrics_post) def test_api_view_method_not_allowed(target_application): - target_application.post('/api_view/', status=405) + target_application.post("/api_view/", status=405) def test_application_view_agent_disabled(target_application): settings = global_settings() - @override_generic_settings(settings, {'enabled': False}) - @function_not_called('newrelic.core.stats_engine', - 'StatsEngine.record_transaction') + @override_generic_settings(settings, {"enabled": False}) + @function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction") def _test(): - response = target_application.get('/view/') + response = target_application.get("/view/") assert response.status_int == 200 - response.mustcontain('restframework view response') + response.mustcontain("restframework view response") _test() diff --git a/tests/component_flask_rest/test_application.py b/tests/component_flask_rest/test_application.py index d463a0205..67d4825a1 100644 --- a/tests/component_flask_rest/test_application.py +++ b/tests/component_flask_rest/test_application.py @@ -31,8 +31,6 @@ from newrelic.core.config import global_settings from newrelic.packages import six -TEST_APPLICATION_PREFIX = "_test_application.create_app." if six.PY3 else "_test_application" - @pytest.fixture(params=["flask_restful", "flask_restx"]) def application(request): @@ -62,7 +60,7 @@ def application(request): ] -@validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".IndexResource", "get") +@validate_code_level_metrics("_test_application.create_app..IndexResource", "get", py2_namespace="_test_application.IndexResource") @validate_transaction_errors(errors=[]) @validate_transaction_metrics("_test_application:index", scoped_metrics=_test_application_index_scoped_metrics) def test_application_index(application): @@ -88,7 +86,7 @@ def test_application_index(application): ], ) def test_application_raises(exception, status_code, ignore_status_code, propagate_exceptions, application): - @validate_code_level_metrics(TEST_APPLICATION_PREFIX + ".ExceptionResource", "get") + @validate_code_level_metrics("_test_application.create_app..ExceptionResource", "get", py2_namespace="_test_application.ExceptionResource") @validate_transaction_metrics("_test_application:exception", scoped_metrics=_test_application_raises_scoped_metrics) def _test(): try: @@ -118,4 +116,4 @@ def test_application_outside_transaction(application): def _test(): application.get("/exception/werkzeug.exceptions:HTTPException/404", status=404) - _test() + _test() \ No newline at end of file diff --git a/tests/component_graphqlserver/__init__.py b/tests/component_graphqlserver/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/component_graphqlserver/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/component_graphqlserver/_target_schema_async.py b/tests/component_graphqlserver/_target_schema_async.py new file mode 100644 index 000000000..aff587bc8 --- /dev/null +++ b/tests/component_graphqlserver/_target_schema_async.py @@ -0,0 +1,155 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLUnionType, +) + +from ._target_schema_sync import books, libraries, magazines + +storage = [] + + +async def resolve_library(parent, info, index): + return libraries[index] + + +async def resolve_storage_add(parent, info, string): + storage.append(string) + return string + + +async def resolve_storage(parent, info): + return [storage.pop()] + + +async def resolve_search(parent, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +Author = GraphQLObjectType( + "Author", + { + "first_name": GraphQLField(GraphQLString), + "last_name": GraphQLField(GraphQLString), + }, +) + +Book = GraphQLObjectType( + "Book", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "isbn": GraphQLField(GraphQLString), + "author": GraphQLField(Author), + "branch": GraphQLField(GraphQLString), + }, +) + +Magazine = GraphQLObjectType( + "Magazine", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "issue": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + }, +) + + +Library = GraphQLObjectType( + "Library", + { + "id": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + "book": GraphQLField(GraphQLList(Book)), + "magazine": GraphQLField(GraphQLList(Magazine)), + }, +) + +Storage = GraphQLList(GraphQLString) + + +async def resolve_hello(root, info): + return "Hello!" + + +async def resolve_echo(root, info, echo): + return echo + + +async def resolve_error(root, info): + raise RuntimeError("Runtime Error!") + + +hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) +library_field = GraphQLField( + Library, + resolver=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolver=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolver=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolver=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolver=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) + +query = GraphQLObjectType( + name="Query", + fields={ + "hello": hello_field, + "library": library_field, + "search": search_field, + "echo": echo_field, + "storage": storage_field, + "error": error_field, + "error_non_null": error_non_null_field, + "error_middleware": error_middleware_field, + }, +) + +mutation = GraphQLObjectType( + name="Mutation", + fields={ + "storage_add": storage_add_field, + }, +) + +target_schema = GraphQLSchema(query=query, mutation=mutation) diff --git a/tests/component_graphqlserver/_test_graphql.py b/tests/component_graphqlserver/_test_graphql.py index 50b5621f9..7a29b3a8f 100644 --- a/tests/component_graphqlserver/_test_graphql.py +++ b/tests/component_graphqlserver/_test_graphql.py @@ -12,15 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from flask import Flask +from sanic import Sanic import json - import webtest -from flask import Flask -from framework_graphql._target_application import _target_application as schema + +from testing_support.asgi_testing import AsgiTest +from framework_graphql._target_schema_sync import target_schema as schema from graphql_server.flask import GraphQLView as FlaskView from graphql_server.sanic import GraphQLView as SanicView -from sanic import Sanic -from testing_support.asgi_testing import AsgiTest + +# Sanic +target_application = dict() def set_middlware(middleware, view_middleware): @@ -95,5 +98,4 @@ def flask_execute(query, middleware=None): return response - target_application["Flask"] = flask_execute diff --git a/tests/component_graphqlserver/test_graphql.py b/tests/component_graphqlserver/test_graphql.py index e5566047e..098f50970 100644 --- a/tests/component_graphqlserver/test_graphql.py +++ b/tests/component_graphqlserver/test_graphql.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. + import importlib import pytest from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.validators.validate_span_events import validate_span_events from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.common.object_names import callable_name @@ -36,7 +41,7 @@ def is_graphql_2(): @pytest.fixture(scope="session", params=("Sanic", "Flask")) def target_application(request): - import _test_graphql + from . import _test_graphql framework = request.param version = importlib.import_module(framework.lower()).__version__ @@ -186,7 +191,7 @@ def test_middleware(target_application): _test_middleware_metrics = [ ("GraphQL/operation/GraphQLServer/query//hello", 1), ("GraphQL/resolve/GraphQLServer/hello", 1), - ("Function/test_graphql:example_middleware", 1), + ("Function/component_graphqlserver.test_graphql:example_middleware", 1), ] # Base span count 6: Transaction, View, Operation, Middleware, and 1 Resolver and Resolver function @@ -220,7 +225,7 @@ def test_exception_in_middleware(target_application): _test_exception_rollup_metrics = [ ("Errors/all", 1), ("Errors/allWeb", 1), - ("Errors/WebTransaction/GraphQL/test_graphql:error_middleware", 1), + ("Errors/WebTransaction/GraphQL/component_graphqlserver.test_graphql:error_middleware", 1), ] + _test_exception_scoped_metrics # Attributes @@ -237,7 +242,7 @@ def test_exception_in_middleware(target_application): } @validate_transaction_metrics( - "test_graphql:error_middleware", + "component_graphqlserver.test_graphql:error_middleware", "GraphQL", scoped_metrics=_test_exception_scoped_metrics, rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, @@ -257,7 +262,7 @@ def test_exception_in_resolver(target_application, field): framework, version, target_application = target_application query = "query MyQuery { %s }" % field - txn_name = "framework_graphql._target_application:resolve_error" + txn_name = "framework_graphql._target_schema_sync:resolve_error" # Metrics _test_exception_scoped_metrics = [ @@ -488,7 +493,7 @@ def _test(): def test_deepest_unique_path(target_application, query, expected_path): framework, version, target_application = target_application if expected_path == "/error": - txn_name = "framework_graphql._target_application:resolve_error" + txn_name = "framework_graphql._target_schema_sync:resolve_error" else: txn_name = "query/%s" % expected_path diff --git a/tests/cross_agent/test_agent_attributes.py b/tests/cross_agent/test_agent_attributes.py index c254be772..527b31a75 100644 --- a/tests/cross_agent/test_agent_attributes.py +++ b/tests/cross_agent/test_agent_attributes.py @@ -40,7 +40,8 @@ def _default_settings(): 'browser_monitoring.attributes.exclude': [], } -FIXTURE = os.path.join(os.curdir, 'fixtures', 'attribute_configuration.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'attribute_configuration.json') def _load_tests(): with open(FIXTURE, 'r') as fh: diff --git a/tests/cross_agent/test_datstore_instance.py b/tests/cross_agent/test_datastore_instance.py similarity index 52% rename from tests/cross_agent/test_datstore_instance.py rename to tests/cross_agent/test_datastore_instance.py index aa095400f..79a95e0be 100644 --- a/tests/cross_agent/test_datstore_instance.py +++ b/tests/cross_agent/test_datastore_instance.py @@ -14,34 +14,40 @@ import json import os + import pytest from newrelic.api.background_task import background_task -from newrelic.api.database_trace import (register_database_client, - enable_datastore_instance_feature) +from newrelic.api.database_trace import register_database_client from newrelic.api.transaction import current_transaction from newrelic.core.database_node import DatabaseNode from newrelic.core.stats_engine import StatsEngine -FIXTURE = os.path.join(os.curdir, - 'fixtures', 'datastores', 'datastore_instances.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, "fixtures", "datastores", "datastore_instances.json") -_parameters_list = ['name', 'system_hostname', 'db_hostname', - 'product', 'port', 'unix_socket', 'database_path', - 'expected_instance_metric'] +_parameters_list = [ + "name", + "system_hostname", + "db_hostname", + "product", + "port", + "unix_socket", + "database_path", + "expected_instance_metric", +] -_parameters = ','.join(_parameters_list) +_parameters = ",".join(_parameters_list) def _load_tests(): - with open(FIXTURE, 'r') as fh: + with open(FIXTURE, "r") as fh: js = fh.read() return json.loads(js) def _parametrize_test(test): - return tuple([test.get(f, None if f != 'db_hostname' else 'localhost') - for f in _parameters_list]) + return tuple([test.get(f, None if f != "db_hostname" else "localhost") for f in _parameters_list]) _datastore_tests = [_parametrize_test(t) for t in _load_tests()] @@ -49,45 +55,44 @@ def _parametrize_test(test): @pytest.mark.parametrize(_parameters, _datastore_tests) @background_task() -def test_datastore_instance(name, system_hostname, db_hostname, - product, port, unix_socket, database_path, - expected_instance_metric, monkeypatch): +def test_datastore_instance( + name, system_hostname, db_hostname, product, port, unix_socket, database_path, expected_instance_metric, monkeypatch +): - monkeypatch.setattr('newrelic.common.system_info.gethostname', - lambda: system_hostname) + monkeypatch.setattr("newrelic.common.system_info.gethostname", lambda: system_hostname) - class FakeModule(): + class FakeModule: pass register_database_client(FakeModule, product) - enable_datastore_instance_feature(FakeModule) port_path_or_id = port or database_path or unix_socket - node = DatabaseNode(dbapi2_module=FakeModule, - sql='', - children=[], - start_time=0, - end_time=1, - duration=1, - exclusive=1, - stack_trace=None, - sql_format='obfuscated', - connect_params=None, - cursor_params=None, - sql_parameters=None, - execute_params=None, - host=db_hostname, - port_path_or_id=port_path_or_id, - database_name=database_path, - guid=None, - agent_attributes={}, - user_attributes={}, + node = DatabaseNode( + dbapi2_module=FakeModule, + sql="", + children=[], + start_time=0, + end_time=1, + duration=1, + exclusive=1, + stack_trace=None, + sql_format="obfuscated", + connect_params=None, + cursor_params=None, + sql_parameters=None, + execute_params=None, + host=db_hostname, + port_path_or_id=port_path_or_id, + database_name=database_path, + guid=None, + agent_attributes={}, + user_attributes={}, ) empty_stats = StatsEngine() transaction = current_transaction() - unscoped_scope = '' + unscoped_scope = "" # Check 'Datastore/instance' metric to confirm that: # 1. metric name is reported correctly diff --git a/tests/cross_agent/test_docker.py b/tests/cross_agent/test_docker.py index 9bc1a7363..fd919932b 100644 --- a/tests/cross_agent/test_docker.py +++ b/tests/cross_agent/test_docker.py @@ -19,7 +19,8 @@ import newrelic.common.utilization as u -DOCKER_FIXTURE = os.path.join(os.curdir, 'fixtures', 'docker_container_id') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +DOCKER_FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'docker_container_id') def _load_docker_test_attributes(): diff --git a/tests/cross_agent/test_labels_and_rollups.py b/tests/cross_agent/test_labels_and_rollups.py index d333ec35b..15ebb1e36 100644 --- a/tests/cross_agent/test_labels_and_rollups.py +++ b/tests/cross_agent/test_labels_and_rollups.py @@ -21,7 +21,8 @@ from testing_support.fixtures import override_application_settings -FIXTURE = os.path.join(os.curdir, 'fixtures', 'labels.json') +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, 'fixtures', 'labels.json') def _load_tests(): with open(FIXTURE, 'r') as fh: diff --git a/tests/cross_agent/test_rules.py b/tests/cross_agent/test_rules.py index e37db787c..ce2983c90 100644 --- a/tests/cross_agent/test_rules.py +++ b/tests/cross_agent/test_rules.py @@ -16,23 +16,23 @@ import os import pytest -from newrelic.core.rules_engine import RulesEngine, NormalizationRule +from newrelic.api.application import application_instance +from newrelic.api.background_task import background_task +from newrelic.api.transaction import record_custom_metric +from newrelic.core.rules_engine import RulesEngine + +from testing_support.validators.validate_metric_payload import validate_metric_payload CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) FIXTURE = os.path.normpath(os.path.join( CURRENT_DIR, 'fixtures', 'rules.json')) + def _load_tests(): with open(FIXTURE, 'r') as fh: js = fh.read() return json.loads(js) -def _prepare_rules(test_rules): - # ensure all keys are present, if not present set to an empty string - for rule in test_rules: - for key in NormalizationRule._fields: - rule[key] = rule.get(key, '') - return test_rules def _make_case_insensitive(rules): # lowercase each rule @@ -42,14 +42,14 @@ def _make_case_insensitive(rules): rule['replacement'] = rule['replacement'].lower() return rules + @pytest.mark.parametrize('test_group', _load_tests()) def test_rules_engine(test_group): # FIXME: The test fixture assumes that matching is case insensitive when it # is not. To avoid errors, just lowercase all rules, inputs, and expected # values. - insense_rules = _make_case_insensitive(test_group['rules']) - test_rules = _prepare_rules(insense_rules) + test_rules = _make_case_insensitive(test_group['rules']) rules_engine = RulesEngine(test_rules) for test in test_group['tests']: @@ -66,3 +66,46 @@ def test_rules_engine(test_group): assert expected == '' else: assert result == expected + + +@pytest.mark.parametrize('test_group', _load_tests()) +def test_rules_engine_metric_harvest(test_group): + # FIXME: The test fixture assumes that matching is case insensitive when it + # is not. To avoid errors, just lowercase all rules, inputs, and expected + # values. + test_rules = _make_case_insensitive(test_group['rules']) + rules_engine = RulesEngine(test_rules) + + # Set rules engine on core application + api_application = application_instance(activate=False) + api_name = api_application.name + core_application = api_application._agent.application(api_name) + old_rules = core_application._rules_engine["metric"] # save previoius rules + core_application._rules_engine["metric"] = rules_engine + + def send_metrics(): + # Send all metrics in this test batch in one transaction, then harvest so the normalizer is run. + @background_task(name="send_metrics") + def _test(): + for test in test_group['tests']: + # lowercase each value + input_str = test['input'].lower() + record_custom_metric(input_str, {"count": 1}) + _test() + core_application.harvest() + + try: + # Create a map of all result metrics to validate after harvest + test_metrics = [] + for test in test_group['tests']: + expected = (test['expected'] or '').lower() + if expected == '': # Ignored + test_metrics.append((expected, None)) + else: + test_metrics.append((expected, 1)) + + # Harvest and validate resulting payload + validate_metric_payload(metrics=test_metrics)(send_metrics)() + finally: + # Replace original rules engine + core_application._rules_engine["metric"] = old_rules diff --git a/tests/cross_agent/test_rum_client_config.py b/tests/cross_agent/test_rum_client_config.py index c2a4a465f..5b8da4b84 100644 --- a/tests/cross_agent/test_rum_client_config.py +++ b/tests/cross_agent/test_rum_client_config.py @@ -26,10 +26,11 @@ ) from newrelic.api.wsgi_application import wsgi_application +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +FIXTURE = os.path.join(CURRENT_DIR, "fixtures", "rum_client_config.json") def _load_tests(): - fixture = os.path.join(os.curdir, "fixtures", "rum_client_config.json") - with open(fixture, "r") as fh: + with open(FIXTURE, "r") as fh: js = fh.read() return json.loads(js) diff --git a/tests/datastore_bmemcached/test_memcache.py b/tests/datastore_bmemcached/test_memcache.py index 68eee0633..2f87da113 100644 --- a/tests/datastore_bmemcached/test_memcache.py +++ b/tests/datastore_bmemcached/test_memcache.py @@ -13,83 +13,94 @@ # limitations under the License. import os -from testing_support.db_settings import memcached_settings + import bmemcached +from testing_support.db_settings import memcached_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.db_settings import memcached_settings - DB_SETTINGS = memcached_settings()[0] -MEMCACHED_HOST = DB_SETTINGS['host'] -MEMCACHED_PORT = DB_SETTINGS['port'] +MEMCACHED_HOST = DB_SETTINGS["host"] +MEMCACHED_PORT = DB_SETTINGS["port"] MEMCACHED_NAMESPACE = str(os.getpid()) -MEMCACHED_ADDR = '%s:%s' % (MEMCACHED_HOST, MEMCACHED_PORT) +MEMCACHED_ADDR = "%s:%s" % (MEMCACHED_HOST, MEMCACHED_PORT) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = bmemcached.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" + _test_wt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = bmemcached.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py new file mode 100644 index 000000000..28e138fa2 --- /dev/null +++ b/tests/datastore_firestore/conftest.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import uuid + +import pytest + +from google.cloud.firestore import Client +from google.cloud.firestore import Client, AsyncClient + +from testing_support.db_settings import firestore_settings +from testing_support.fixture.event_loop import event_loop as loop # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) + +from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace +from newrelic.common.system_info import LOCALHOST_EQUIVALENTS, gethostname + +DB_SETTINGS = firestore_settings()[0] +FIRESTORE_HOST = DB_SETTINGS["host"] +FIRESTORE_PORT = DB_SETTINGS["port"] + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (datastore_firestore)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) + + +@pytest.fixture() +def instance_info(): + host = gethostname() if FIRESTORE_HOST in LOCALHOST_EQUIVALENTS else FIRESTORE_HOST + return {"host": host, "port_path_or_id": str(FIRESTORE_PORT), "db.instance": "projects/google-cloud-firestore-emulator/databases/(default)"} + + +@pytest.fixture(scope="session") +def client(): + os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) + client = Client() + # Ensure connection is available + client.collection("healthcheck").document("healthcheck").set( + {}, retry=None, timeout=5 + ) + return client + + +@pytest.fixture(scope="function") +def collection(client): + collection_ = client.collection("firestore_collection_" + str(uuid.uuid4())) + yield collection_ + client.recursive_delete(collection_) + + +@pytest.fixture(scope="session") +def async_client(loop): + os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) + client = AsyncClient() + loop.run_until_complete(client.collection("healthcheck").document("healthcheck").set({}, retry=None, timeout=5)) # Ensure connection is available + return client + + +@pytest.fixture(scope="function") +def async_collection(async_client, collection): + # Use the same collection name as the collection fixture + yield async_client.collection(collection.id) + + +@pytest.fixture(scope="session") +def assert_trace_for_generator(): + def _assert_trace_for_generator(generator_func, *args, **kwargs): + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + # Check for generator trace on collections + _trace_check = [] + for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_generator + + +@pytest.fixture(scope="session") +def assert_trace_for_async_generator(loop): + def _assert_trace_for_async_generator(generator_func, *args, **kwargs): + _trace_check = [] + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + async def coro(): + # Check for generator trace on collections + async for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + + loop.run_until_complete(coro()) + + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_async_generator diff --git a/tests/datastore_firestore/test_async_batching.py b/tests/datastore_firestore/test_async_batching.py new file mode 100644 index 000000000..39e532a04 --- /dev/null +++ b/tests/datastore_firestore/test_async_batching.py @@ -0,0 +1,73 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_write_batch(async_client, async_collection): + async def _exercise_async_write_batch(): + docs = [async_collection.document(str(x)) for x in range(1, 4)] + async_batch = async_client.batch() + for doc in docs: + async_batch.set(doc, {}) + + await async_batch.commit() + + return _exercise_async_write_batch + + +def test_firestore_async_write_batch(loop, exercise_async_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_write_batch") + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() + + +def test_firestore_async_write_batch_trace_node_datastore_params(loop, exercise_async_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_client.py b/tests/datastore_firestore/test_async_client.py new file mode 100644 index 000000000..1c7518bf0 --- /dev/null +++ b/tests/datastore_firestore/test_async_client.py @@ -0,0 +1,87 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def existing_document(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_async_client(async_client, existing_document): + async def _exercise_async_client(): + assert len([_ async for _ in async_client.collections()]) >= 1 + doc = [_ async for _ in async_client.get_all([existing_document])][0] + assert doc.to_dict()["x"] == 1 + + return _exercise_async_client + + +def test_firestore_async_client(loop, exercise_async_client, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_client") + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() + + +@background_task() +def test_firestore_async_client_generators(async_client, collection, assert_trace_for_async_generator): + doc = collection.document("test") + doc.set({}) + + assert_trace_for_async_generator(async_client.collections) + assert_trace_for_async_generator(async_client.get_all, [doc]) + + +def test_firestore_async_client_trace_node_datastore_params(loop, exercise_async_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_collections.py b/tests/datastore_firestore/test_async_collections.py new file mode 100644 index 000000000..214ee2939 --- /dev/null +++ b/tests/datastore_firestore/test_async_collections.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_collections(async_collection): + async def _exercise_async_collections(): + async_collection.document("DoesNotExist") + await async_collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + await async_collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = await async_collection.get() + assert len(documents_get) == 2 + documents_stream = [_ async for _ in async_collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ async for _ in async_collection.list_documents()] + assert len(documents_list) == 2 + + return _exercise_async_collections + + +def test_firestore_async_collections(loop, exercise_async_collections, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % async_collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collections") + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() + + +@background_task() +def test_firestore_async_collections_generators(collection, async_collection, assert_trace_for_async_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_async_generator(async_collection.stream) + assert_trace_for_async_generator(async_collection.list_documents) + + +def test_firestore_async_collections_trace_node_datastore_params(loop, exercise_async_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_documents.py b/tests/datastore_firestore/test_async_documents.py new file mode 100644 index 000000000..c90693208 --- /dev/null +++ b/tests/datastore_firestore/test_async_documents.py @@ -0,0 +1,108 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_async_documents(async_collection): + async def _exercise_async_documents(): + italy_doc = async_collection.document("Italy") + await italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + await italy_doc.get() + italian_cities = italy_doc.collection("cities") + await italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ async for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = async_collection.document("USA") + await usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + await usa_doc.update({"president": "Joe Biden"}) + + await async_collection.document("USA").delete() + + return _exercise_async_documents + + +def test_firestore_async_documents(loop, exercise_async_documents, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 7), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_documents") + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() + + +@background_task() +def test_firestore_async_documents_generators( + collection, async_collection, assert_trace_for_async_generator, instance_info +): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + async_subcollection = async_collection.document(subcollection_doc.id) + + assert_trace_for_async_generator(async_subcollection.collections) + + +def test_firestore_async_documents_trace_node_datastore_params(loop, exercise_async_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_query.py b/tests/datastore_firestore/test_async_query.py new file mode 100644 index 000000000..1bc579b7f --- /dev/null +++ b/tests/datastore_firestore/test_async_query.py @@ -0,0 +1,249 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== AsyncQuery ===== + + +@pytest.fixture() +def exercise_async_query(async_collection): + async def _exercise_async_query(): + async_query = ( + async_collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + ) + assert len(await async_query.get()) == 3 + assert len([_ async for _ in async_query.stream()]) == 3 + + return _exercise_async_query + + +def test_firestore_async_query(loop, exercise_async_query, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + # @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_query") + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + + +@background_task() +def test_firestore_async_query_generators(async_collection, assert_trace_for_async_generator): + async_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_async_generator(async_query.stream) + + +def test_firestore_async_query_trace_node_datastore_params(loop, exercise_async_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + + +# ===== AsyncAggregationQuery ===== + + +@pytest.fixture() +def exercise_async_aggregation_query(async_collection): + async def _exercise_async_aggregation_query(): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert (await async_aggregation_query.get())[0][0].value == 3 + assert [_ async for _ in async_aggregation_query.stream()][0][0].value == 3 + + return _exercise_async_aggregation_query + + +def test_firestore_async_aggregation_query(loop, exercise_async_aggregation_query, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_aggregation_query") + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +@background_task() +def test_firestore_async_aggregation_query_generators(async_collection, assert_trace_for_async_generator): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_async_generator(async_aggregation_query.stream) + + +def test_firestore_async_aggregation_query_trace_node_datastore_params( + loop, exercise_async_aggregation_query, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, async_client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a coroutine that returns an async_generator of Cursor objects. + Each Cursor must point at a valid document path. To test this, we can patch the RPC to return 1 Cursor + which is pointed at any document available. The get_partitions will take that and make 2 QueryPartition + objects out of it, which should be enough to ensure we can exercise the generator's tracing. + """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + async def mock_partition_query(*args, **kwargs): + async def _mock_partition_query(): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + return _mock_partition_query() + + monkeypatch.setattr(async_client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_async_collection_group(async_client, async_collection): + async def _exercise_async_collection_group(): + async_collection_group = async_client.collection_group(async_collection.id) + assert len(await async_collection_group.get()) + assert len([d async for d in async_collection_group.stream()]) + + partitions = [p async for p in async_collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(await partitions.pop().query().get()) + assert len(documents) == 6 + + return _exercise_async_collection_group + + +def test_firestore_async_collection_group( + loop, exercise_async_collection_group, async_collection, patch_partition_queries, instance_info +): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collection_group") + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() + + +@background_task() +def test_firestore_async_collection_group_generators( + async_client, async_collection, assert_trace_for_async_generator, patch_partition_queries +): + async_collection_group = async_client.collection_group(async_collection.id) + assert_trace_for_async_generator(async_collection_group.get_partitions, 1) + + +def test_firestore_async_collection_group_trace_node_datastore_params( + loop, exercise_async_collection_group, instance_info, patch_partition_queries +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_async_transaction.py b/tests/datastore_firestore/test_async_transaction.py new file mode 100644 index 000000000..2b8646ec5 --- /dev/null +++ b/tests/datastore_firestore/test_async_transaction.py @@ -0,0 +1,169 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_async_transaction_commit(async_client, async_collection): + async def _exercise_async_transaction_commit(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # get a DocumentReference + with pytest.raises( + TypeError + ): # get is currently broken. It attempts to await an async_generator instead of consuming it. + [_ async for _ in async_transaction.get(async_collection.document("doc1"))] + + # get a Query + with pytest.raises( + TypeError + ): # get is currently broken. It attempts to await an async_generator instead of consuming it. + async_query = async_collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ async for _ in async_transaction.get(async_query)]) == 1 + + # get_all on a list of DocumentReferences + with pytest.raises( + TypeError + ): # get_all is currently broken. It attempts to await an async_generator instead of consuming it. + all_docs = async_transaction.get_all([async_collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ async for _ in all_docs]) == 3 + + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 0}) + async_transaction.delete(async_collection.document("doc3")) + + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 2 + + return _exercise_async_transaction_commit + + +@pytest.fixture() +def exercise_async_transaction_rollback(async_client, async_collection): + async def _exercise_async_transaction_rollback(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 99}) + async_transaction.delete(async_collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 3 + + return _exercise_async_transaction_rollback + + +def test_firestore_async_transaction_commit(loop, exercise_async_transaction_commit, async_collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + # ("Datastore/operation/Firestore/get_all", 2), + # ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + # ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), # Should be 5 if not for broken APIs + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback( + loop, exercise_async_transaction_rollback, async_collection, instance_info +): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() + + +def test_firestore_async_transaction_commit_trace_node_datastore_params( + loop, exercise_async_transaction_commit, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback_trace_node_datastore_params( + loop, exercise_async_transaction_rollback, instance_info +): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py new file mode 100644 index 000000000..07964338c --- /dev/null +++ b/tests/datastore_firestore/test_batching.py @@ -0,0 +1,127 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + +# ===== WriteBatch ===== + + +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) + + batch.commit() + + return _exercise_write_batch + + +def test_firestore_write_batch(exercise_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_write_batch") + def _test(): + exercise_write_batch() + + _test() + + +def test_firestore_write_batch_trace_node_datastore_params(exercise_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_write_batch() + + _test() + + +# ===== BulkWriteBatch ===== + + +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def _exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) + + batch.commit() + + return _exercise_bulk_write_batch + + +def test_firestore_bulk_write_batch(exercise_bulk_write_batch, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_bulk_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_bulk_write_batch") + def _test(): + exercise_bulk_write_batch() + + _test() + + +def test_firestore_bulk_write_batch_trace_node_datastore_params(exercise_bulk_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_bulk_write_batch() + + _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py new file mode 100644 index 000000000..81fbd181c --- /dev/null +++ b/tests/datastore_firestore/test_client.py @@ -0,0 +1,83 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def sample_data(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + + return _exercise_client + + +def test_firestore_client(exercise_client, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_client") + def _test(): + exercise_client() + + _test() + + +@background_task() +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): + assert_trace_for_generator(client.collections) + assert_trace_for_generator(client.get_all, [sample_data]) + + +def test_firestore_client_trace_node_datastore_params(exercise_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_client() + + _test() \ No newline at end of file diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py new file mode 100644 index 000000000..2e58bbe95 --- /dev/null +++ b/tests/datastore_firestore/test_collections.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + + return _exercise_collections + + +def test_firestore_collections(exercise_collections, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collections") + def _test(): + exercise_collections() + + _test() + + +@background_task() +def test_firestore_collections_generators(collection, assert_trace_for_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_generator(collection.stream) + assert_trace_for_generator(collection.list_documents) + + +def test_firestore_collections_trace_node_datastore_params(exercise_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collections() + + _test() diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py new file mode 100644 index 000000000..ae6b94edd --- /dev/null +++ b/tests/datastore_firestore/test_documents.py @@ -0,0 +1,104 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) + + collection.document("USA").delete() + + return _exercise_documents + + +def test_firestore_documents(exercise_documents, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 7), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_documents") + def _test(): + exercise_documents() + + _test() + + +@background_task() +def test_firestore_documents_generators(collection, assert_trace_for_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + assert_trace_for_generator(subcollection_doc.collections) + + +def test_firestore_documents_trace_node_datastore_params(exercise_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_documents() + + _test() diff --git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py new file mode 100644 index 000000000..6f1643c5b --- /dev/null +++ b/tests/datastore_firestore/test_query.py @@ -0,0 +1,236 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== Query ===== + + +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + + return _exercise_query + + +def test_firestore_query(exercise_query, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_query") + def _test(): + exercise_query() + + _test() + + +@background_task() +def test_firestore_query_generators(collection, assert_trace_for_generator): + query = collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_generator(query.stream) + + +def test_firestore_query_trace_node_datastore_params(exercise_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_query() + + _test() + + +# ===== AggregationQuery ===== + + +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + + return _exercise_aggregation_query + + +def test_firestore_aggregation_query(exercise_aggregation_query, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_aggregation_query") + def _test(): + exercise_aggregation_query() + + _test() + + +@background_task() +def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_generator(aggregation_query.stream) + + +def test_firestore_aggregation_query_trace_node_datastore_params(exercise_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_aggregation_query() + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path. + To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available. + The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure + we can exercise the generator's tracing. + """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + def mock_partition_query(*args, **kwargs): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_collection_group(client, collection, patch_partition_queries): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collection_group") + def _test(): + exercise_collection_group() + + _test() + + +@background_task() +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) + assert_trace_for_generator(collection_group.get_partitions, 1) + + +def test_firestore_collection_group_trace_node_datastore_params(exercise_collection_group, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collection_group() + + _test() diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 000000000..59d496a00 --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,153 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + + return _exercise_transaction_commit + + +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + + return _exercise_transaction_rollback + + +def test_firestore_transaction_commit(exercise_transaction_commit, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback(exercise_transaction_rollback, collection, instance_info): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/instance/Firestore/%s/%s" % (instance_info["host"], instance_info["port_path_or_id"]), 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_rollback() + + _test() + + +def test_firestore_transaction_commit_trace_node_datastore_params(exercise_transaction_commit, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback_trace_node_datastore_params(exercise_transaction_rollback, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_transaction_rollback() + + _test() diff --git a/tests/datastore_mysql/test_database.py b/tests/datastore_mysql/test_database.py index 2fc8ca129..8f8641903 100644 --- a/tests/datastore_mysql/test_database.py +++ b/tests/datastore_mysql/test_database.py @@ -13,11 +13,15 @@ # limitations under the License. import mysql.connector - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import mysql_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + from newrelic.api.background_task import background_task DB_SETTINGS = mysql_settings() @@ -27,80 +31,95 @@ mysql_version = tuple(int(x) for x in mysql.connector.__version__.split(".")[:3]) if mysql_version >= (8, 0, 30): - _connector_metric_name = 'Function/mysql.connector.pooling:connect' + _connector_metric_name = "Function/mysql.connector.pooling:connect" else: - _connector_metric_name = 'Function/mysql.connector:connect' + _connector_metric_name = "Function/mysql.connector:connect" _test_execute_via_cursor_scoped_metrics = [ - (_connector_metric_name, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + (_connector_metric_name, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=dict) @background_task() def test_execute_via_cursor(table_name): - connection = mysql.connector.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = mysql.connector.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() cursor.execute("""drop table if exists `%s`""" % table_name) - cursor.execute("""create table %s """ - """(a integer, b real, c text)""" % table_name) + cursor.execute("""create table %s """ """(a integer, b real, c text)""" % table_name) - cursor.executemany("""insert into `%s` """ % table_name + - """values (%(a)s, %(b)s, %(c)s)""", [dict(a=1, b=1.0, c='1.0'), - dict(a=2, b=2.2, c='2.2'), dict(a=3, b=3.3, c='3.3')]) + cursor.executemany( + """insert into `%s` """ % table_name + """values (%(a)s, %(b)s, %(c)s)""", + [dict(a=1, b=1.0, c="1.0"), dict(a=2, b=2.2, c="2.2"), dict(a=3, b=3.3, c="3.3")], + ) cursor.execute("""select * from %s""" % table_name) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update `%s` """ % table_name + - """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", - dict(a=4, b=4.0, c='4.0', old_a=1)) + cursor.execute( + """update `%s` """ % table_name + """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", + dict(a=4, b=4.0, c="4.0", old_a=1), + ) cursor.execute("""delete from `%s` where a=2""" % table_name) cursor.execute("""drop procedure if exists %s""" % DB_PROCEDURE) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % DB_PROCEDURE) + END""" + % DB_PROCEDURE + ) cursor.callproc("%s" % DB_PROCEDURE) @@ -108,76 +127,92 @@ def test_execute_via_cursor(table_name): connection.rollback() connection.commit() + _test_connect_using_alias_scoped_metrics = [ - (_connector_metric_name, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + (_connector_metric_name, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_connect_using_alias_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/select' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/insert' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/update' % DB_NAMESPACE, 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/datastore_mysql_%s/delete' % DB_NAMESPACE, 1), - ('Datastore/statement/MySQL/%s/call' % DB_PROCEDURE, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_connect_using_alias', - scoped_metrics=_test_connect_using_alias_scoped_metrics, - rollup_metrics=_test_connect_using_alias_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/select" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/insert" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/update" % DB_NAMESPACE, 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/datastore_mysql_%s/delete" % DB_NAMESPACE, 1), + ("Datastore/statement/MySQL/%s/call" % DB_PROCEDURE, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_connect_using_alias", + scoped_metrics=_test_connect_using_alias_scoped_metrics, + rollup_metrics=_test_connect_using_alias_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=dict) @background_task() def test_connect_using_alias(table_name): - connection = mysql.connector.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = mysql.connector.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() cursor.execute("""drop table if exists `%s`""" % table_name) - cursor.execute("""create table %s """ - """(a integer, b real, c text)""" % table_name) + cursor.execute("""create table %s """ """(a integer, b real, c text)""" % table_name) - cursor.executemany("""insert into `%s` """ % table_name + - """values (%(a)s, %(b)s, %(c)s)""", [dict(a=1, b=1.0, c='1.0'), - dict(a=2, b=2.2, c='2.2'), dict(a=3, b=3.3, c='3.3')]) + cursor.executemany( + """insert into `%s` """ % table_name + """values (%(a)s, %(b)s, %(c)s)""", + [dict(a=1, b=1.0, c="1.0"), dict(a=2, b=2.2, c="2.2"), dict(a=3, b=3.3, c="3.3")], + ) cursor.execute("""select * from %s""" % table_name) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update `%s` """ % table_name + - """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", - dict(a=4, b=4.0, c='4.0', old_a=1)) + cursor.execute( + """update `%s` """ % table_name + """set a=%(a)s, b=%(b)s, c=%(c)s where a=%(old_a)s""", + dict(a=4, b=4.0, c="4.0", old_a=1), + ) cursor.execute("""delete from `%s` where a=2""" % table_name) cursor.execute("""drop procedure if exists %s""" % DB_PROCEDURE) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % DB_PROCEDURE) + END""" + % DB_PROCEDURE + ) cursor.callproc("%s" % DB_PROCEDURE) diff --git a/tests/datastore_postgresql/conftest.py b/tests/datastore_postgresql/conftest.py index 624fb4726..4a25f2574 100644 --- a/tests/datastore_postgresql/conftest.py +++ b/tests/datastore_postgresql/conftest.py @@ -12,21 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { - 'transaction_tracer.explain_threshold': 0.0, - 'transaction_tracer.transaction_threshold': 0.0, - 'transaction_tracer.stack_trace_threshold': 0.0, - 'debug.log_data_collector_payloads': True, - 'debug.record_transaction_failure': True, - 'debug.log_explain_plan_queries': True + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, } collector_agent_registration = collector_agent_registration_fixture( - app_name='Python Agent Test (datastore_postgresql)', - default_settings=_default_settings, - linked_applications=['Python Agent Test (datastore)']) + app_name="Python Agent Test (datastore_postgresql)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) diff --git a/tests/datastore_postgresql/test_database.py b/tests/datastore_postgresql/test_database.py index 2ea930b05..cf432d174 100644 --- a/tests/datastore_postgresql/test_database.py +++ b/tests/datastore_postgresql/test_database.py @@ -13,15 +13,14 @@ # limitations under the License. import postgresql.driver.dbapi20 - - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics - +from testing_support.db_settings import postgresql_settings +from testing_support.util import instance_hostname from testing_support.validators.validate_database_trace_inputs import ( validate_database_trace_inputs, ) - -from testing_support.db_settings import postgresql_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task @@ -41,13 +40,14 @@ ("Datastore/operation/Postgres/create", 1), ("Datastore/operation/Postgres/commit", 3), ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/operation/Postgres/other", 1), ] _test_execute_via_cursor_rollup_metrics = [ - ("Datastore/all", 13), - ("Datastore/allOther", 13), - ("Datastore/Postgres/all", 13), - ("Datastore/Postgres/allOther", 13), + ("Datastore/all", 14), + ("Datastore/allOther", 14), + ("Datastore/Postgres/all", 14), + ("Datastore/Postgres/allOther", 14), ("Datastore/operation/Postgres/select", 1), ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), ("Datastore/operation/Postgres/insert", 1), @@ -63,6 +63,11 @@ ("Datastore/operation/Postgres/call", 2), ("Datastore/operation/Postgres/commit", 3), ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/operation/Postgres/other", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 13), + ("Function/postgresql.driver.dbapi20:connect", 1), + ("Function/postgresql.driver.dbapi20:Connection.__enter__", 1), + ("Function/postgresql.driver.dbapi20:Connection.__exit__", 1), ] @@ -82,30 +87,27 @@ def test_execute_via_cursor(): host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], ) as connection: - cursor = connection.cursor() cursor.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) - cursor.execute( - """create table %s """ % DB_SETTINGS["table_name"] - + """(a integer, b real, c text)""" - ) + cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + """(a integer, b real, c text)""") cursor.executemany( - """insert into %s """ % DB_SETTINGS["table_name"] - + """values (%s, %s, %s)""", + """insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], ) cursor.execute("""select * from %s""" % DB_SETTINGS["table_name"]) - for row in cursor: - pass + cursor.execute( + """with temporaryTable (averageValue) as (select avg(b) from %s) """ % DB_SETTINGS["table_name"] + + """select * from %s,temporaryTable """ % DB_SETTINGS["table_name"] + + """where %s.b > temporaryTable.averageValue""" % DB_SETTINGS["table_name"] + ) cursor.execute( - """update %s """ % DB_SETTINGS["table_name"] - + """set a=%s, b=%s, c=%s where a=%s""", + """update %s """ % DB_SETTINGS["table_name"] + """set a=%s, b=%s, c=%s where a=%s""", (4, 4.0, "4.0", 1), ) @@ -152,7 +154,6 @@ def test_rollback_on_exception(): host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], ): - raise RuntimeError("error") except RuntimeError: diff --git a/tests/datastore_psycopg2cffi/test_database.py b/tests/datastore_psycopg2cffi/test_database.py index 54ff6ad09..939c5cabc 100644 --- a/tests/datastore_psycopg2cffi/test_database.py +++ b/tests/datastore_psycopg2cffi/test_database.py @@ -15,166 +15,190 @@ import psycopg2cffi import psycopg2cffi.extensions import psycopg2cffi.extras - -from testing_support.fixtures import validate_stats_engine_explain_plan_output_is_none -from testing_support.validators.validate_transaction_errors import validate_transaction_errors -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_transaction_slow_sql_count import \ - validate_transaction_slow_sql_count -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import postgresql_settings +from testing_support.fixtures import validate_stats_engine_explain_plan_output_is_none +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_transaction_slow_sql_count import ( + validate_transaction_slow_sql_count, +) from newrelic.api.background_task import background_task DB_SETTINGS = postgresql_settings()[0] _test_execute_via_cursor_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__enter__', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__exit__', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/update' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/delete' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/now/call', 1), - ('Datastore/statement/Postgres/pg_sleep/call', 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1), - ('Datastore/operation/Postgres/commit', 3), - ('Datastore/operation/Postgres/rollback', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__enter__", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__exit__", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/update" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/delete" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/now/call", 1), + ("Datastore/statement/Postgres/pg_sleep/call", 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/operation/Postgres/commit", 3), + ("Datastore/operation/Postgres/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/Postgres/all', 13), - ('Datastore/Postgres/allOther', 13), - ('Datastore/operation/Postgres/select', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/insert', 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/update', 1), - ('Datastore/statement/Postgres/%s/update' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/delete', 1), - ('Datastore/statement/Postgres/%s/delete' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1), - ('Datastore/statement/Postgres/now/call', 1), - ('Datastore/statement/Postgres/pg_sleep/call', 1), - ('Datastore/operation/Postgres/call', 2), - ('Datastore/operation/Postgres/commit', 3), - ('Datastore/operation/Postgres/rollback', 1)] - - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/Postgres/all", 13), + ("Datastore/Postgres/allOther", 13), + ("Datastore/operation/Postgres/select", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/insert", 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/update", 1), + ("Datastore/statement/Postgres/%s/update" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/delete", 1), + ("Datastore/statement/Postgres/%s/delete" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/statement/Postgres/now/call", 1), + ("Datastore/statement/Postgres/pg_sleep/call", 1), + ("Datastore/operation/Postgres/call", 2), + ("Datastore/operation/Postgres/commit", 3), + ("Datastore/operation/Postgres/rollback", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor(): with psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], host=DB_SETTINGS['host'], - port=DB_SETTINGS['port']) as connection: + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) as connection: cursor = connection.cursor() psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE) - psycopg2cffi.extensions.register_type( - psycopg2cffi.extensions.UNICODE, - connection) - psycopg2cffi.extensions.register_type( - psycopg2cffi.extensions.UNICODE, - cursor) + psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE, connection) + psycopg2cffi.extensions.register_type(psycopg2cffi.extensions.UNICODE, cursor) cursor.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) - cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + - """(a integer, b real, c text)""") + cursor.execute("""create table %s """ % DB_SETTINGS["table_name"] + """(a integer, b real, c text)""") - cursor.executemany("""insert into %s """ % DB_SETTINGS["table_name"] + - """values (%s, %s, %s)""", [(1, 1.0, '1.0'), - (2, 2.2, '2.2'), (3, 3.3, '3.3')]) + cursor.executemany( + """insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", + [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], + ) cursor.execute("""select * from %s""" % DB_SETTINGS["table_name"]) for row in cursor: pass - cursor.execute("""update %s""" % DB_SETTINGS["table_name"] + """ set a=%s, b=%s, """ - """c=%s where a=%s""", (4, 4.0, '4.0', 1)) + cursor.execute( + """update %s""" % DB_SETTINGS["table_name"] + """ set a=%s, b=%s, """ """c=%s where a=%s""", + (4, 4.0, "4.0", 1), + ) cursor.execute("""delete from %s where a=2""" % DB_SETTINGS["table_name"]) connection.commit() - cursor.callproc('now') - cursor.callproc('pg_sleep', (0,)) + cursor.callproc("now") + cursor.callproc("pg_sleep", (0,)) connection.rollback() connection.commit() _test_rollback_on_exception_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__enter__', 1), - ('Function/psycopg2cffi._impl.connection:Connection.__exit__', 1), - ('Datastore/operation/Postgres/rollback', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__enter__", 1), + ("Function/psycopg2cffi._impl.connection:Connection.__exit__", 1), + ("Datastore/operation/Postgres/rollback", 1), +] _test_rollback_on_exception_rollup_metrics = [ - ('Datastore/all', 2), - ('Datastore/allOther', 2), - ('Datastore/Postgres/all', 2), - ('Datastore/Postgres/allOther', 2)] - - -@validate_transaction_metrics('test_database:test_rollback_on_exception', - scoped_metrics=_test_rollback_on_exception_scoped_metrics, - rollup_metrics=_test_rollback_on_exception_rollup_metrics, - background_task=True) + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ("Datastore/Postgres/all", 2), + ("Datastore/Postgres/allOther", 2), +] + + +@validate_transaction_metrics( + "test_database:test_rollback_on_exception", + scoped_metrics=_test_rollback_on_exception_scoped_metrics, + rollup_metrics=_test_rollback_on_exception_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_rollback_on_exception(): try: with psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], host=DB_SETTINGS['host'], - port=DB_SETTINGS['port']): - - raise RuntimeError('error') + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ): + + raise RuntimeError("error") except RuntimeError: pass _test_async_mode_scoped_metrics = [ - ('Function/psycopg2cffi:connect', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1)] + ("Function/psycopg2cffi:connect", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), +] _test_async_mode_rollup_metrics = [ - ('Datastore/all', 5), - ('Datastore/allOther', 5), - ('Datastore/Postgres/all', 5), - ('Datastore/Postgres/allOther', 5), - ('Datastore/operation/Postgres/select', 1), - ('Datastore/statement/Postgres/%s/select' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/insert', 1), - ('Datastore/statement/Postgres/%s/insert' % DB_SETTINGS["table_name"], 1), - ('Datastore/operation/Postgres/drop', 1), - ('Datastore/operation/Postgres/create', 1)] + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/Postgres/all", 5), + ("Datastore/Postgres/allOther", 5), + ("Datastore/operation/Postgres/select", 1), + ("Datastore/statement/Postgres/%s/select" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/insert", 1), + ("Datastore/statement/Postgres/%s/insert" % DB_SETTINGS["table_name"], 1), + ("Datastore/operation/Postgres/drop", 1), + ("Datastore/operation/Postgres/create", 1), + ("Datastore/instance/Postgres/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 4), +] @validate_stats_engine_explain_plan_output_is_none() @validate_transaction_slow_sql_count(num_slow_sql=4) @validate_database_trace_inputs(sql_parameters_type=tuple) -@validate_transaction_metrics('test_database:test_async_mode', - scoped_metrics=_test_async_mode_scoped_metrics, - rollup_metrics=_test_async_mode_rollup_metrics, - background_task=True) +@validate_transaction_metrics( + "test_database:test_async_mode", + scoped_metrics=_test_async_mode_scoped_metrics, + rollup_metrics=_test_async_mode_rollup_metrics, + background_task=True, +) @validate_transaction_errors(errors=[]) @background_task() def test_async_mode(): @@ -182,16 +206,19 @@ def test_async_mode(): wait = psycopg2cffi.extras.wait_select kwargs = {} - version = tuple(int(_) for _ in psycopg2cffi.__version__.split('.')) + version = tuple(int(_) for _ in psycopg2cffi.__version__.split(".")) if version >= (2, 8): - kwargs['async_'] = 1 + kwargs["async_"] = 1 else: - kwargs['async'] = 1 + kwargs["async"] = 1 async_conn = psycopg2cffi.connect( - database=DB_SETTINGS['name'], user=DB_SETTINGS['user'], - password=DB_SETTINGS['password'], host=DB_SETTINGS['host'], - port=DB_SETTINGS['port'], **kwargs + database=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + password=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + **kwargs ) wait(async_conn) async_cur = async_conn.cursor() @@ -199,12 +226,10 @@ def test_async_mode(): async_cur.execute("""drop table if exists %s""" % DB_SETTINGS["table_name"]) wait(async_cur.connection) - async_cur.execute("""create table %s """ % DB_SETTINGS["table_name"] + - """(a integer, b real, c text)""") + async_cur.execute("""create table %s """ % DB_SETTINGS["table_name"] + """(a integer, b real, c text)""") wait(async_cur.connection) - async_cur.execute("""insert into %s """ % DB_SETTINGS["table_name"] + - """values (%s, %s, %s)""", (1, 1.0, '1.0')) + async_cur.execute("""insert into %s """ % DB_SETTINGS["table_name"] + """values (%s, %s, %s)""", (1, 1.0, "1.0")) wait(async_cur.connection) async_cur.execute("""select * from %s""" % DB_SETTINGS["table_name"]) diff --git a/tests/datastore_pylibmc/test_memcache.py b/tests/datastore_pylibmc/test_memcache.py index 769f3b483..64da33416 100644 --- a/tests/datastore_pylibmc/test_memcache.py +++ b/tests/datastore_pylibmc/test_memcache.py @@ -12,85 +12,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import pylibmc - from testing_support.db_settings import memcached_settings -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task - DB_SETTINGS = memcached_settings()[0] MEMCACHED_HOST = DB_SETTINGS["host"] MEMCACHED_PORT = DB_SETTINGS["port"] MEMCACHED_NAMESPACE = DB_SETTINGS["namespace"] -MEMCACHED_ADDR = '%s:%s' % (MEMCACHED_HOST, MEMCACHED_PORT) +MEMCACHED_ADDR = "%s:%s" % (MEMCACHED_HOST, MEMCACHED_PORT) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = pylibmc.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" + _test_wt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = pylibmc.Client([MEMCACHED_ADDR]) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, 'value') + client.set(key, "value") value = client.get(key) client.delete(key) - assert value == 'value' + assert value == "value" diff --git a/tests/datastore_pymemcache/test_memcache.py b/tests/datastore_pymemcache/test_memcache.py index 9aeea4d54..3100db5b7 100644 --- a/tests/datastore_pymemcache/test_memcache.py +++ b/tests/datastore_pymemcache/test_memcache.py @@ -15,9 +15,10 @@ import os import pymemcache.client - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import memcached_settings +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task from newrelic.api.transaction import set_background_task @@ -31,65 +32,74 @@ MEMCACHED_ADDR = (MEMCACHED_HOST, int(MEMCACHED_PORT)) _test_bt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_bt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allOther', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allOther', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allOther", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_bt_set_get_delete', - scoped_metrics=_test_bt_set_get_delete_scoped_metrics, - rollup_metrics=_test_bt_set_get_delete_rollup_metrics, - background_task=True) + "test_memcache:test_bt_set_get_delete", + scoped_metrics=_test_bt_set_get_delete_scoped_metrics, + rollup_metrics=_test_bt_set_get_delete_rollup_metrics, + background_task=True, +) @background_task() def test_bt_set_get_delete(): set_background_task(True) client = pymemcache.client.Client(MEMCACHED_ADDR) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, b'value') + client.set(key, b"value") value = client.get(key) client.delete(key) - assert value == b'value' + assert value == b"value" + _test_wt_set_get_delete_scoped_metrics = [ - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] _test_wt_set_get_delete_rollup_metrics = [ - ('Datastore/all', 3), - ('Datastore/allWeb', 3), - ('Datastore/Memcached/all', 3), - ('Datastore/Memcached/allWeb', 3), - ('Datastore/operation/Memcached/set', 1), - ('Datastore/operation/Memcached/get', 1), - ('Datastore/operation/Memcached/delete', 1)] + ("Datastore/all", 3), + ("Datastore/allWeb", 3), + ("Datastore/Memcached/all", 3), + ("Datastore/Memcached/allWeb", 3), + ("Datastore/operation/Memcached/set", 1), + ("Datastore/operation/Memcached/get", 1), + ("Datastore/operation/Memcached/delete", 1), +] + @validate_transaction_metrics( - 'test_memcache:test_wt_set_get_delete', - scoped_metrics=_test_wt_set_get_delete_scoped_metrics, - rollup_metrics=_test_wt_set_get_delete_rollup_metrics, - background_task=False) + "test_memcache:test_wt_set_get_delete", + scoped_metrics=_test_wt_set_get_delete_scoped_metrics, + rollup_metrics=_test_wt_set_get_delete_rollup_metrics, + background_task=False, +) @background_task() def test_wt_set_get_delete(): set_background_task(False) client = pymemcache.client.Client(MEMCACHED_ADDR) - key = MEMCACHED_NAMESPACE + 'key' + key = MEMCACHED_NAMESPACE + "key" - client.set(key, b'value') + client.set(key, b"value") value = client.get(key) client.delete(key) - assert value == b'value' + assert value == b"value" diff --git a/tests/datastore_pymysql/test_database.py b/tests/datastore_pymysql/test_database.py index 5943b1266..ad4db1d9c 100644 --- a/tests/datastore_pymysql/test_database.py +++ b/tests/datastore_pymysql/test_database.py @@ -13,11 +13,14 @@ # limitations under the License. import pymysql - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_database_trace_inputs import validate_database_trace_inputs - from testing_support.db_settings import mysql_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_database_trace_inputs import ( + validate_database_trace_inputs, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task @@ -25,76 +28,92 @@ TABLE_NAME = "datastore_pymysql_" + DB_SETTINGS["namespace"] PROCEDURE_NAME = "hello_" + DB_SETTINGS["namespace"] +HOST = instance_hostname(DB_SETTINGS["host"]) +PORT = DB_SETTINGS["port"] + def execute_db_calls_with_cursor(cursor): cursor.execute("""drop table if exists %s""" % TABLE_NAME) - cursor.execute("""create table %s """ % TABLE_NAME + - """(a integer, b real, c text)""") + cursor.execute("""create table %s """ % TABLE_NAME + """(a integer, b real, c text)""") - cursor.executemany("""insert into %s """ % TABLE_NAME + - """values (%s, %s, %s)""", [(1, 1.0, '1.0'), - (2, 2.2, '2.2'), (3, 3.3, '3.3')]) + cursor.executemany( + """insert into %s """ % TABLE_NAME + """values (%s, %s, %s)""", + [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")], + ) cursor.execute("""select * from %s""" % TABLE_NAME) - for row in cursor: pass + for row in cursor: + pass - cursor.execute("""update %s""" % TABLE_NAME + """ set a=%s, b=%s, """ - """c=%s where a=%s""", (4, 4.0, '4.0', 1)) + cursor.execute("""update %s""" % TABLE_NAME + """ set a=%s, b=%s, """ """c=%s where a=%s""", (4, 4.0, "4.0", 1)) cursor.execute("""delete from %s where a=2""" % TABLE_NAME) cursor.execute("""drop procedure if exists %s""" % PROCEDURE_NAME) - cursor.execute("""CREATE PROCEDURE %s() + cursor.execute( + """CREATE PROCEDURE %s() BEGIN SELECT 'Hello World!'; - END""" % PROCEDURE_NAME) + END""" + % PROCEDURE_NAME + ) cursor.callproc(PROCEDURE_NAME) _test_execute_via_cursor_scoped_metrics = [ - ('Function/pymysql:Connect', 1), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Function/pymysql:Connect", 1), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] - -@validate_transaction_metrics('test_database:test_execute_via_cursor', - scoped_metrics=_test_execute_via_cursor_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_rollup_metrics, - background_task=True) + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (HOST, PORT), 12), +] + + +@validate_transaction_metrics( + "test_database:test_execute_via_cursor", + scoped_metrics=_test_execute_via_cursor_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor(): - connection = pymysql.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = pymysql.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) with connection.cursor() as cursor: execute_db_calls_with_cursor(cursor) @@ -105,49 +124,57 @@ def test_execute_via_cursor(): _test_execute_via_cursor_context_mangaer_scoped_metrics = [ - ('Function/pymysql:Connect', 1), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Function/pymysql:Connect", 1), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), +] _test_execute_via_cursor_context_mangaer_rollup_metrics = [ - ('Datastore/all', 13), - ('Datastore/allOther', 13), - ('Datastore/MySQL/all', 13), - ('Datastore/MySQL/allOther', 13), - ('Datastore/statement/MySQL/%s/select' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/insert' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/update' % TABLE_NAME, 1), - ('Datastore/statement/MySQL/%s/delete' % TABLE_NAME, 1), - ('Datastore/operation/MySQL/select', 1), - ('Datastore/operation/MySQL/insert', 1), - ('Datastore/operation/MySQL/update', 1), - ('Datastore/operation/MySQL/delete', 1), - ('Datastore/statement/MySQL/%s/call' % PROCEDURE_NAME, 1), - ('Datastore/operation/MySQL/call', 1), - ('Datastore/operation/MySQL/drop', 2), - ('Datastore/operation/MySQL/create', 2), - ('Datastore/operation/MySQL/commit', 2), - ('Datastore/operation/MySQL/rollback', 1)] + ("Datastore/all", 13), + ("Datastore/allOther", 13), + ("Datastore/MySQL/all", 13), + ("Datastore/MySQL/allOther", 13), + ("Datastore/statement/MySQL/%s/select" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/insert" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/update" % TABLE_NAME, 1), + ("Datastore/statement/MySQL/%s/delete" % TABLE_NAME, 1), + ("Datastore/operation/MySQL/select", 1), + ("Datastore/operation/MySQL/insert", 1), + ("Datastore/operation/MySQL/update", 1), + ("Datastore/operation/MySQL/delete", 1), + ("Datastore/statement/MySQL/%s/call" % PROCEDURE_NAME, 1), + ("Datastore/operation/MySQL/call", 1), + ("Datastore/operation/MySQL/drop", 2), + ("Datastore/operation/MySQL/create", 2), + ("Datastore/operation/MySQL/commit", 2), + ("Datastore/operation/MySQL/rollback", 1), + ("Datastore/instance/MySQL/%s/%s" % (HOST, PORT), 12), +] @validate_transaction_metrics( - 'test_database:test_execute_via_cursor_context_manager', - scoped_metrics=_test_execute_via_cursor_context_mangaer_scoped_metrics, - rollup_metrics=_test_execute_via_cursor_context_mangaer_rollup_metrics, - background_task=True) + "test_database:test_execute_via_cursor_context_manager", + scoped_metrics=_test_execute_via_cursor_context_mangaer_scoped_metrics, + rollup_metrics=_test_execute_via_cursor_context_mangaer_rollup_metrics, + background_task=True, +) @validate_database_trace_inputs(sql_parameters_type=tuple) @background_task() def test_execute_via_cursor_context_manager(): - connection = pymysql.connect(db=DB_SETTINGS['name'], - user=DB_SETTINGS['user'], passwd=DB_SETTINGS['password'], - host=DB_SETTINGS['host'], port=DB_SETTINGS['port']) + connection = pymysql.connect( + db=DB_SETTINGS["name"], + user=DB_SETTINGS["user"], + passwd=DB_SETTINGS["password"], + host=DB_SETTINGS["host"], + port=DB_SETTINGS["port"], + ) cursor = connection.cursor() with cursor: diff --git a/tests/datastore_pyodbc/test_pyodbc.py b/tests/datastore_pyodbc/test_pyodbc.py index 119908e4d..5a810be5f 100644 --- a/tests/datastore_pyodbc/test_pyodbc.py +++ b/tests/datastore_pyodbc/test_pyodbc.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest from testing_support.db_settings import postgresql_settings +from testing_support.util import instance_hostname from testing_support.validators.validate_database_trace_inputs import ( validate_database_trace_inputs, ) diff --git a/tests/datastore_pysolr/test_solr.py b/tests/datastore_pysolr/test_solr.py index a987a29ac..e17117117 100644 --- a/tests/datastore_pysolr/test_solr.py +++ b/tests/datastore_pysolr/test_solr.py @@ -13,16 +13,19 @@ # limitations under the License. from pysolr import Solr - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import solr_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task DB_SETTINGS = solr_settings()[0] SOLR_HOST = DB_SETTINGS["host"] SOLR_PORT = DB_SETTINGS["port"] -SOLR_URL = 'http://%s:%s/solr/collection' % (DB_SETTINGS["host"], DB_SETTINGS["port"]) +SOLR_URL = "http://%s:%s/solr/collection" % (DB_SETTINGS["host"], DB_SETTINGS["port"]) + def _exercise_solr(solr): # Construct document names within namespace @@ -31,30 +34,36 @@ def _exercise_solr(solr): solr.add([{"id": x} for x in documents]) - solr.search('id:%s' % documents[0]) + solr.search("id:%s" % documents[0]) solr.delete(id=documents[0]) # Delete all documents. - solr.delete(q='id:*_%s' % DB_SETTINGS["namespace"]) + solr.delete(q="id:*_%s" % DB_SETTINGS["namespace"]) + _test_solr_search_scoped_metrics = [ - ('Datastore/operation/Solr/add', 1), - ('Datastore/operation/Solr/delete', 2), - ('Datastore/operation/Solr/search', 1)] + ("Datastore/operation/Solr/add", 1), + ("Datastore/operation/Solr/delete", 2), + ("Datastore/operation/Solr/search", 1), +] _test_solr_search_rollup_metrics = [ - ('Datastore/all', 4), - ('Datastore/allOther', 4), - ('Datastore/Solr/all', 4), - ('Datastore/Solr/allOther', 4), - ('Datastore/operation/Solr/add', 1), - ('Datastore/operation/Solr/search', 1), - ('Datastore/operation/Solr/delete', 2)] - -@validate_transaction_metrics('test_solr:test_solr_search', + ("Datastore/all", 4), + ("Datastore/allOther", 4), + ("Datastore/Solr/all", 4), + ("Datastore/Solr/allOther", 4), + ("Datastore/operation/Solr/add", 1), + ("Datastore/operation/Solr/search", 1), + ("Datastore/operation/Solr/delete", 2), +] + + +@validate_transaction_metrics( + "test_solr:test_solr_search", scoped_metrics=_test_solr_search_scoped_metrics, rollup_metrics=_test_solr_search_rollup_metrics, - background_task=True) + background_task=True, +) @background_task() def test_solr_search(): s = Solr(SOLR_URL) diff --git a/tests/datastore_redis/test_asyncio.py b/tests/datastore_redis/test_asyncio.py index f17d6b2e0..f46e8515e 100644 --- a/tests/datastore_redis/test_asyncio.py +++ b/tests/datastore_redis/test_asyncio.py @@ -30,7 +30,9 @@ DB_SETTINGS = redis_settings()[0] REDIS_PY_VERSION = get_package_version_tuple("redis") -# Metrics +# Metrics for publish test + +datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 _base_scoped_metrics = [("Datastore/operation/Redis/publish", 3)] @@ -39,8 +41,6 @@ ("Datastore/operation/Redis/client_setinfo", 2), ) -datastore_all_metric_count = 5 if REDIS_PY_VERSION >= (5, 0) else 3 - _base_rollup_metrics = [ ("Datastore/all", datastore_all_metric_count), ("Datastore/allOther", datastore_all_metric_count), @@ -57,6 +57,27 @@ ("Datastore/operation/Redis/client_setinfo", 2), ) + +# Metrics for connection pool test + +_base_pool_scoped_metrics = [ + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), +] + +_base_pool_rollup_metrics = [ + ("Datastore/all", 3), + ("Datastore/allOther", 3), + ("Datastore/Redis/all", 3), + ("Datastore/Redis/allOther", 3), + ("Datastore/operation/Redis/get", 1), + ("Datastore/operation/Redis/set", 1), + ("Datastore/operation/Redis/client_list", 1), + ("Datastore/instance/Redis/%s/%s" % (instance_hostname(DB_SETTINGS["host"]), DB_SETTINGS["port"]), 3), +] + + # Tests @@ -67,6 +88,31 @@ def client(loop): # noqa return loop.run_until_complete(redis.asyncio.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)) +@pytest.fixture() +def client_pool(loop): # noqa + import redis.asyncio + + connection_pool = redis.asyncio.ConnectionPool(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0) + return loop.run_until_complete(redis.asyncio.Redis(connection_pool=connection_pool)) + + +@pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+") +@validate_transaction_metrics( + "test_asyncio:test_async_connection_pool", + scoped_metrics=_base_pool_scoped_metrics, + rollup_metrics=_base_pool_rollup_metrics, + background_task=True, +) +@background_task() +def test_async_connection_pool(client_pool, loop): # noqa + async def _test_async_pool(client_pool): + await client_pool.set("key1", "value1") + await client_pool.get("key1") + await client_pool.execute_command("CLIENT", "LIST") + + loop.run_until_complete(_test_async_pool(client_pool)) + + @pytest.mark.skipif(REDIS_PY_VERSION < (4, 2), reason="This functionality exists in Redis 4.2+") @validate_transaction_metrics("test_asyncio:test_async_pipeline", background_task=True) @background_task() diff --git a/tests/datastore_solrpy/test_solr.py b/tests/datastore_solrpy/test_solr.py index ee1a7e91e..56dcce62b 100644 --- a/tests/datastore_solrpy/test_solr.py +++ b/tests/datastore_solrpy/test_solr.py @@ -13,16 +13,19 @@ # limitations under the License. from solr import SolrConnection - -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics from testing_support.db_settings import solr_settings +from testing_support.util import instance_hostname +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) from newrelic.api.background_task import background_task DB_SETTINGS = solr_settings()[0] SOLR_HOST = DB_SETTINGS["host"] SOLR_PORT = DB_SETTINGS["port"] -SOLR_URL = 'http://%s:%s/solr/collection' % (DB_SETTINGS["host"], DB_SETTINGS["port"]) +SOLR_URL = "http://%s:%s/solr/collection" % (DB_SETTINGS["host"], DB_SETTINGS["port"]) + def _exercise_solr(solr): # Construct document names within namespace @@ -31,30 +34,37 @@ def _exercise_solr(solr): solr.add_many([{"id": x} for x in documents]) solr.commit() - solr.query('id:%s' % documents[0]).results - solr.delete('id:*_%s' % DB_SETTINGS["namespace"]) + solr.query("id:%s" % documents[0]).results + solr.delete("id:*_%s" % DB_SETTINGS["namespace"]) solr.commit() + _test_solr_search_scoped_metrics = [ - ('Datastore/operation/Solr/add_many', 1), - ('Datastore/operation/Solr/delete', 1), - ('Datastore/operation/Solr/commit', 2), - ('Datastore/operation/Solr/query', 1)] + ("Datastore/operation/Solr/add_many", 1), + ("Datastore/operation/Solr/delete", 1), + ("Datastore/operation/Solr/commit", 2), + ("Datastore/operation/Solr/query", 1), +] _test_solr_search_rollup_metrics = [ - ('Datastore/all', 5), - ('Datastore/allOther', 5), - ('Datastore/Solr/all', 5), - ('Datastore/Solr/allOther', 5), - ('Datastore/operation/Solr/add_many', 1), - ('Datastore/operation/Solr/query', 1), - ('Datastore/operation/Solr/commit', 2), - ('Datastore/operation/Solr/delete', 1)] - -@validate_transaction_metrics('test_solr:test_solr_search', + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ("Datastore/Solr/all", 5), + ("Datastore/Solr/allOther", 5), + ("Datastore/instance/Solr/%s/%s" % (instance_hostname(SOLR_HOST), SOLR_PORT), 3), + ("Datastore/operation/Solr/add_many", 1), + ("Datastore/operation/Solr/query", 1), + ("Datastore/operation/Solr/commit", 2), + ("Datastore/operation/Solr/delete", 1), +] + + +@validate_transaction_metrics( + "test_solr:test_solr_search", scoped_metrics=_test_solr_search_scoped_metrics, rollup_metrics=_test_solr_search_rollup_metrics, - background_task=True) + background_task=True, +) @background_task() def test_solr_search(): s = SolrConnection(SOLR_URL) diff --git a/tests/framework_ariadne/__init__.py b/tests/framework_ariadne/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_ariadne/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_ariadne/_target_application.py b/tests/framework_ariadne/_target_application.py index 94bc0710f..fef782608 100644 --- a/tests/framework_ariadne/_target_application.py +++ b/tests/framework_ariadne/_target_application.py @@ -12,140 +12,125 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -from ariadne import ( - MutationType, - QueryType, - UnionType, - load_schema_from_path, - make_executable_schema, + +import asyncio +import json + +from framework_ariadne._target_schema_async import ( + target_asgi_application as target_asgi_application_async, +) +from framework_ariadne._target_schema_async import target_schema as target_schema_async +from framework_ariadne._target_schema_sync import ( + target_asgi_application as target_asgi_application_sync, +) +from framework_ariadne._target_schema_sync import target_schema as target_schema_sync +from framework_ariadne._target_schema_sync import ( + target_wsgi_application as target_wsgi_application_sync, ) -from ariadne.asgi import GraphQL as GraphQLASGI -from ariadne.wsgi import GraphQL as GraphQLWSGI +from framework_ariadne._target_schema_sync import ariadne_version_tuple +from graphql import MiddlewareManager -schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") -type_defs = load_schema_from_path(schema_file) - - -authors = [ - { - "first_name": "New", - "last_name": "Relic", - }, - { - "first_name": "Bob", - "last_name": "Smith", - }, - { - "first_name": "Leslie", - "last_name": "Jones", - }, -] -books = [ - { - "id": 1, - "name": "Python Agent: The Book", - "isbn": "a-fake-isbn", - "author": authors[0], - "branch": "riverside", - }, - { - "id": 2, - "name": "Ollies for O11y: A Sk8er's Guide to Observability", - "isbn": "a-second-fake-isbn", - "author": authors[1], - "branch": "downtown", - }, - { - "id": 3, - "name": "[Redacted]", - "isbn": "a-third-fake-isbn", - "author": authors[2], - "branch": "riverside", - }, -] -magazines = [ - {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, - {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, - {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, -] +def check_response(query, success, response): + if isinstance(query, str) and "error" not in query: + assert success and "errors" not in response, response + assert response.get("data", None), response + else: + assert "errors" in response, response -libraries = ["riverside", "downtown"] -libraries = [ - { - "id": i + 1, - "branch": branch, - "magazine": [m for m in magazines if m["branch"] == branch], - "book": [b for b in books if b["branch"] == branch], - } - for i, branch in enumerate(libraries) -] +def run_sync(schema): + def _run_sync(query, middleware=None): + from ariadne import graphql_sync -storage = [] + if ariadne_version_tuple < (0, 18): + if middleware: + middleware = MiddlewareManager(*middleware) + success, response = graphql_sync(schema, {"query": query}, middleware=middleware) + check_response(query, success, response) -mutation = MutationType() + return response.get("data", {}) + return _run_sync -@mutation.field("storage_add") -def mutate(self, info, string): - storage.append(string) - return {"string": string} +def run_async(schema): + def _run_async(query, middleware=None): + from ariadne import graphql -item = UnionType("Item") + #Later versions of ariadne directly accept a list of middleware while older versions require the MiddlewareManager + if ariadne_version_tuple < (0, 18): + if middleware: + middleware = MiddlewareManager(*middleware) + loop = asyncio.get_event_loop() + success, response = loop.run_until_complete(graphql(schema, {"query": query}, middleware=middleware)) + check_response(query, success, response) -@item.type_resolver -def resolve_type(obj, *args): - if "isbn" in obj: - return "Book" - elif "issue" in obj: # pylint: disable=R1705 - return "Magazine" + return response.get("data", {}) - return None + return _run_async -query = QueryType() +def run_wsgi(app): + def _run_asgi(query, middleware=None): + if not isinstance(query, str) or "error" in query: + expect_errors = True + else: + expect_errors = False + app.app.middleware = middleware -@query.field("library") -def resolve_library(self, info, index): - return libraries[index] + response = app.post( + "/", json.dumps({"query": query}), headers={"Content-Type": "application/json"}, expect_errors=expect_errors + ) + body = json.loads(response.body.decode("utf-8")) + if expect_errors: + assert body["errors"] + else: + assert "errors" not in body or not body["errors"] -@query.field("storage") -def resolve_storage(self, info): - return storage + return body.get("data", {}) + return _run_asgi -@query.field("search") -def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b["name"]] - search_magazines = [m for m in magazines if contains in m["name"]] - return search_books + search_magazines +def run_asgi(app): + def _run_asgi(query, middleware=None): + if ariadne_version_tuple < (0, 16): + app.asgi_application.middleware = middleware -@query.field("hello") -def resolve_hello(self, info): - return "Hello!" + #In ariadne v0.16.0, the middleware attribute was removed from the GraphQL class in favor of the http_handler + elif ariadne_version_tuple >= (0, 16): + app.asgi_application.http_handler.middleware = middleware + response = app.make_request( + "POST", "/", body=json.dumps({"query": query}), headers={"Content-Type": "application/json"} + ) + body = json.loads(response.body.decode("utf-8")) -@query.field("echo") -def resolve_echo(self, info, echo): - return echo + if not isinstance(query, str) or "error" in query: + try: + assert response.status != 200 + except AssertionError: + assert body["errors"] + else: + assert response.status == 200 + assert "errors" not in body or not body["errors"] + return body.get("data", {}) -@query.field("error_non_null") -@query.field("error") -def resolve_error(self, info): - raise RuntimeError("Runtime Error!") + return _run_asgi -_target_application = make_executable_schema(type_defs, query, mutation, item) -_target_asgi_application = GraphQLASGI(_target_application) -_target_wsgi_application = GraphQLWSGI(_target_application) +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), + "wsgi-sync": run_wsgi(target_wsgi_application_sync), + "asgi-sync": run_asgi(target_asgi_application_sync), + "asgi-async": run_asgi(target_asgi_application_async), +} diff --git a/tests/framework_ariadne/_target_schema_async.py b/tests/framework_ariadne/_target_schema_async.py new file mode 100644 index 000000000..076475628 --- /dev/null +++ b/tests/framework_ariadne/_target_schema_async.py @@ -0,0 +1,94 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ariadne import ( + MutationType, + QueryType, + UnionType, + load_schema_from_path, + make_executable_schema, +) +from ariadne.asgi import GraphQL as GraphQLASGI +from framework_graphql._target_schema_sync import books, magazines, libraries + +from testing_support.asgi_testing import AsgiTest + +schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") +type_defs = load_schema_from_path(schema_file) + +storage = [] + +mutation = MutationType() + + +@mutation.field("storage_add") +async def resolve_storage_add(self, info, string): + storage.append(string) + return string + + +item = UnionType("Item") + + +@item.type_resolver +async def resolve_type(obj, *args): + if "isbn" in obj: + return "Book" + elif "issue" in obj: # pylint: disable=R1705 + return "Magazine" + + return None + + +query = QueryType() + + +@query.field("library") +async def resolve_library(self, info, index): + return libraries[index] + + +@query.field("storage") +async def resolve_storage(self, info): + return [storage.pop()] + + +@query.field("search") +async def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +@query.field("hello") +@query.field("error_middleware") +async def resolve_hello(self, info): + return "Hello!" + + +@query.field("echo") +async def resolve_echo(self, info, echo): + return echo + + +@query.field("error_non_null") +@query.field("error") +async def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + + +target_schema = make_executable_schema(type_defs, query, mutation, item) +target_asgi_application = AsgiTest(GraphQLASGI(target_schema)) diff --git a/tests/framework_ariadne/_target_schema_sync.py b/tests/framework_ariadne/_target_schema_sync.py new file mode 100644 index 000000000..8860e71ac --- /dev/null +++ b/tests/framework_ariadne/_target_schema_sync.py @@ -0,0 +1,106 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import webtest + +from ariadne import ( + MutationType, + QueryType, + UnionType, + load_schema_from_path, + make_executable_schema, +) +from ariadne.wsgi import GraphQL as GraphQLWSGI +from framework_graphql._target_schema_sync import books, magazines, libraries + +from testing_support.asgi_testing import AsgiTest +from framework_ariadne.test_application import ARIADNE_VERSION + +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) + +if ariadne_version_tuple < (0, 16): + from ariadne.asgi import GraphQL as GraphQLASGI +elif ariadne_version_tuple >= (0, 16): + from ariadne.asgi.graphql import GraphQL as GraphQLASGI + + +schema_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.graphql") +type_defs = load_schema_from_path(schema_file) + +storage = [] + +mutation = MutationType() + + + +@mutation.field("storage_add") +def resolve_storage_add(self, info, string): + storage.append(string) + return string + + +item = UnionType("Item") + + +@item.type_resolver +def resolve_type(obj, *args): + if "isbn" in obj: + return "Book" + elif "issue" in obj: # pylint: disable=R1705 + return "Magazine" + + return None + + +query = QueryType() + + +@query.field("library") +def resolve_library(self, info, index): + return libraries[index] + + +@query.field("storage") +def resolve_storage(self, info): + return [storage.pop()] + + +@query.field("search") +def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +@query.field("hello") +@query.field("error_middleware") +def resolve_hello(self, info): + return "Hello!" + + +@query.field("echo") +def resolve_echo(self, info, echo): + return echo + + +@query.field("error_non_null") +@query.field("error") +def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + + +target_schema = make_executable_schema(type_defs, query, mutation, item) +target_asgi_application = AsgiTest(GraphQLASGI(target_schema)) +target_wsgi_application = webtest.TestApp(GraphQLWSGI(target_schema)) \ No newline at end of file diff --git a/tests/framework_ariadne/conftest.py b/tests/framework_ariadne/conftest.py index 93623a685..42b08faba 100644 --- a/tests/framework_ariadne/conftest.py +++ b/tests/framework_ariadne/conftest.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import six -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -31,12 +32,5 @@ ) -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application - - return _target_application - - if six.PY2: collect_ignore = ["test_application_async.py"] diff --git a/tests/framework_ariadne/schema.graphql b/tests/framework_ariadne/schema.graphql index 4c76e0b88..8bf64af51 100644 --- a/tests/framework_ariadne/schema.graphql +++ b/tests/framework_ariadne/schema.graphql @@ -33,7 +33,7 @@ type Magazine { } type Mutation { - storage_add(string: String!): StorageAdd + storage_add(string: String!): String } type Query { @@ -44,8 +44,5 @@ type Query { echo(echo: String!): String error: String error_non_null: String! -} - -type StorageAdd { - string: String + error_middleware: String } diff --git a/tests/framework_ariadne/test_application.py b/tests/framework_ariadne/test_application.py index cf8501a7a..0b7bf2489 100644 --- a/tests/framework_ariadne/test_application.py +++ b/tests/framework_ariadne/test_application.py @@ -11,526 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_count import ( - validate_transaction_count, -) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) - -from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name -from newrelic.common.package_version_utils import get_package_version_tuple - - -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 - - -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, query, *args, **kwargs): - from ariadne import graphql_sync - - return graphql_sync(schema, {"query": query}, *args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/Ariadne/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ hello }") - assert ok and not response.get("errors") - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Ariadne/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Ariadne/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - ok, response = graphql_run(app, 'mutation { storage_add(string: "abc") { string } }') - assert ok and not response.get("errors") - ok, response = graphql_run(app, "query { storage }") - assert ok and not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response["data"]) - assert "abc" in str(response["data"]) - - _test() - - -@dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): - _test_middleware_metrics = [ - ("GraphQL/operation/Ariadne/query//hello", 1), - ("GraphQL/resolve/Ariadne/hello", 1), - ("Function/test_application:example_middleware", 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver function - @validate_span_events(count=5) - @background_task() - def _test(): - from graphql import MiddlewareManager - - middleware = ( - [example_middleware] - if get_package_version_tuple("ariadne") >= (0, 18) - else MiddlewareManager(example_middleware) - ) +from framework_graphql.test_application import * - ok, response = graphql_run(app, "{ hello }", middleware=middleware) - assert ok and not response.get("errors") - assert "Hello!" in str(response["data"]) +from newrelic.common.package_version_utils import get_package_version - _test() +ARIADNE_VERSION = get_package_version("ariadne") +ariadne_version_tuple = tuple(map(int, ARIADNE_VERSION.split("."))) -@dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Ariadne/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - "test_application:error_middleware", - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - from graphql import MiddlewareManager - - middleware = ( - [error_middleware] - if get_package_version_tuple("ariadne") >= (0, 18) - else MiddlewareManager(error_middleware) - ) - - _, response = graphql_run(app, query, middleware=middleware) - assert response["errors"] - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - txn_name = "_target_application:resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Ariadne/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - _, response = graphql_run(app, query) - assert response["errors"] - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], +@pytest.fixture( + scope="session", params=["sync-sync", "async-sync", "async-async", "wsgi-sync", "asgi-sync", "asgi-async"] ) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Ariadne///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - _, response = graphql_run(app, query) - assert response["errors"] - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Ariadne/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert ok and not response.get("errors") - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Ariadne/hello", 1)] - graphql_attrs = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ hello }") - assert ok and not response.get("errors") - assert "Hello!" in str(response["data"]) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - ("{ library(index: 0) { branch } }", "{ library(index: ?) { branch } }"), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - ok, response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert ok and not response.get("errors") - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. (need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - txn_name = "_target_application:resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - ok, response = graphql_run(app, query) - if "error" not in query: - assert ok and not response.get("errors") - - _test() - +def target_application(request): + from ._target_application import target_application -@pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): - txn_ct = 1 if capture_introspection_setting else 0 + target_application = target_application[request.param] - @override_application_settings( - {"instrumentation.graphql.capture_introspection_queries": capture_introspection_setting} - ) - @validate_transaction_count(txn_ct) - @background_task() - def _test(): - ok, response = graphql_run(app, "{ __schema { types { name } } }") - assert ok and not response.get("errors") + param = request.param.split("-") + is_background = param[0] not in {"wsgi", "asgi"} + schema_type = param[1] + extra_spans = 4 if param[0] == "wsgi" else 0 - _test() + assert ARIADNE_VERSION is not None + return "Ariadne", ARIADNE_VERSION, target_application, is_background, schema_type, extra_spans diff --git a/tests/framework_ariadne/test_application_async.py b/tests/framework_ariadne/test_application_async.py deleted file mode 100644 index ada34ffad..000000000 --- a/tests/framework_ariadne/test_application_async.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - -from newrelic.api.background_task import background_task - - -@pytest.fixture(scope="session") -def graphql_run_async(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, query, *args, **kwargs): - from ariadne import graphql - - return graphql(schema, {"query": query}, *args, **kwargs) - - return execute - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Ariadne/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Ariadne/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - ok, response = await graphql_run_async(app, 'mutation { storage_add(string: "abc") { string } }') - assert ok and not response.get("errors") - ok, response = await graphql_run_async(app, "query { storage }") - assert ok and not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - loop = asyncio.new_event_loop() - loop.run_until_complete(coro()) - - _test() diff --git a/tests/framework_ariadne/test_asgi.py b/tests/framework_ariadne/test_asgi.py deleted file mode 100644 index 861f2aa93..000000000 --- a/tests/framework_ariadne/test_asgi.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.asgi_testing import AsgiTest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_asgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_asgi_application - - app = AsgiTest(_target_asgi_application) - - def execute(query): - return app.make_request( - "POST", "/", headers={"Content-Type": "application/json"}, body=json.dumps({"query": query}) - ) - - return execute - - -@dt_enabled -def test_query_and_mutation_asgi(graphql_asgi_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("GraphQL/all", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Ariadne/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - @validate_transaction_metrics( - "mutation//storage_add.string", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_asgi_run('mutation { storage_add(string: "abc") { string } }') - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - response = graphql_asgi_run("query { storage }") - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/framework_ariadne/test_wsgi.py b/tests/framework_ariadne/test_wsgi.py deleted file mode 100644 index 9ce2373d4..000000000 --- a/tests/framework_ariadne/test_wsgi.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import webtest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_wsgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_wsgi_application - - app = webtest.TestApp(_target_wsgi_application) - - def execute(query): - return app.post_json("/", {"query": query}) - - return execute - - -@dt_enabled -def test_query_and_mutation_wsgi(graphql_wsgi_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/Ariadne/None", 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage_add", 1), - ("GraphQL/operation/Ariadne/mutation//storage_add.string", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Ariadne/storage", 1), - ("GraphQL/operation/Ariadne/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("Python/WSGI/Response", 1), - ("GraphQL/all", 1), - ("GraphQL/Ariadne/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Ariadne/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - @validate_transaction_metrics( - "mutation//storage_add.string", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_wsgi_run('mutation { storage_add(string: "abc") { string } }') - assert response.status_code == 200 - response = response.json_body - assert not response.get("errors") - - response = graphql_wsgi_run("query { storage }") - assert response.status_code == 200 - response = response.json_body - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/framework_graphene/__init__.py b/tests/framework_graphene/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_graphene/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_graphene/_target_application.py b/tests/framework_graphene/_target_application.py index 50acc776f..3f4b23e57 100644 --- a/tests/framework_graphene/_target_application.py +++ b/tests/framework_graphene/_target_application.py @@ -11,150 +11,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from graphene import Field, Int, List -from graphene import Mutation as GrapheneMutation -from graphene import NonNull, ObjectType, Schema, String, Union +from ._target_schema_async import target_schema as target_schema_async +from ._target_schema_sync import target_schema as target_schema_sync +from framework_graphene.test_application import GRAPHENE_VERSION -class Author(ObjectType): - first_name = String() - last_name = String() +def check_response(query, response): + if isinstance(query, str) and "error" not in query: + assert not response.errors, response + assert response.data + else: + assert response.errors, response -class Book(ObjectType): - id = Int() - name = String() - isbn = String() - author = Field(Author) - branch = String() +def run_sync(schema): + def _run_sync(query, middleware=None): + response = schema.execute(query, middleware=middleware) + check_response(query, response) + return response.data -class Magazine(ObjectType): - id = Int() - name = String() - issue = Int() - branch = String() + return _run_sync -class Item(Union): - class Meta: - types = (Book, Magazine) +def run_async(schema): + import asyncio + def _run_async(query, middleware=None): + loop = asyncio.get_event_loop() + response = loop.run_until_complete(schema.execute_async(query, middleware=middleware)) + check_response(query, response) -class Library(ObjectType): - id = Int() - branch = String() - magazine = Field(List(Magazine)) - book = Field(List(Book)) + return response.data + return _run_async -Storage = List(String) +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), + } -authors = [ - Author( - first_name="New", - last_name="Relic", - ), - Author( - first_name="Bob", - last_name="Smith", - ), - Author( - first_name="Leslie", - last_name="Jones", - ), -] - -books = [ - Book( - id=1, - name="Python Agent: The Book", - isbn="a-fake-isbn", - author=authors[0], - branch="riverside", - ), - Book( - id=2, - name="Ollies for O11y: A Sk8er's Guide to Observability", - isbn="a-second-fake-isbn", - author=authors[1], - branch="downtown", - ), - Book( - id=3, - name="[Redacted]", - isbn="a-third-fake-isbn", - author=authors[2], - branch="riverside", - ), -] - -magazines = [ - Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), - Magazine(id=2, name="Reli Updates Weekly", issue=2, branch="downtown"), - Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), -] - - -libraries = ["riverside", "downtown"] -libraries = [ - Library( - id=i + 1, - branch=branch, - magazine=[m for m in magazines if m.branch == branch], - book=[b for b in books if b.branch == branch], - ) - for i, branch in enumerate(libraries) -] - -storage = [] - - -class StorageAdd(GrapheneMutation): - class Arguments: - string = String(required=True) - - string = String() - - def mutate(self, info, string): - storage.append(string) - return String(string=string) - - -class Query(ObjectType): - library = Field(Library, index=Int(required=True)) - hello = String() - search = Field(List(Item), contains=String(required=True)) - echo = Field(String, echo=String(required=True)) - storage = Storage - error = String() - - def resolve_library(self, info, index): - return libraries[index] - - def resolve_storage(self, info): - return storage - - def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - - def resolve_hello(self, info): - return "Hello!" - - def resolve_echo(self, info, echo): - return echo - - def resolve_error(self, info): - raise RuntimeError("Runtime Error!") - - error_non_null = Field(NonNull(String), resolver=resolve_error) - - -class Mutation(ObjectType): - storage_add = StorageAdd.Field() - - -_target_application = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/_target_schema_async.py b/tests/framework_graphene/_target_schema_async.py new file mode 100644 index 000000000..39905f2f9 --- /dev/null +++ b/tests/framework_graphene/_target_schema_async.py @@ -0,0 +1,72 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from graphene import Field, Int, List +from graphene import Mutation as GrapheneMutation +from graphene import NonNull, ObjectType, Schema, String, Union + +from ._target_schema_sync import Author, Book, Magazine, Item, Library, Storage, authors, books, magazines, libraries + + +storage = [] + + +async def resolve_library(self, info, index): + return libraries[index] + +async def resolve_storage(self, info): + return [storage.pop()] + +async def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + +async def resolve_hello(self, info): + return "Hello!" + +async def resolve_echo(self, info, echo): + return echo + +async def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + +async def resolve_storage_add(self, info, string): + storage.append(string) + return StorageAdd(string=string) + + +class StorageAdd(GrapheneMutation): + class Arguments: + string = String(required=True) + + string = String() + mutate = resolve_storage_add + + +class Query(ObjectType): + library = Field(Library, index=Int(required=True), resolver=resolve_library) + hello = String(resolver=resolve_hello) + search = Field(List(Item), contains=String(required=True), resolver=resolve_search) + echo = Field(String, echo=String(required=True), resolver=resolve_echo) + storage = Field(Storage, resolver=resolve_storage) + error = String(resolver=resolve_error) + error_non_null = Field(NonNull(String), resolver=resolve_error) + error_middleware = String(resolver=resolve_hello) + + +class Mutation(ObjectType): + storage_add = StorageAdd.Field() + + +target_schema = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/_target_schema_sync.py b/tests/framework_graphene/_target_schema_sync.py new file mode 100644 index 000000000..b59179065 --- /dev/null +++ b/tests/framework_graphene/_target_schema_sync.py @@ -0,0 +1,162 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from graphene import Field, Int, List +from graphene import Mutation as GrapheneMutation +from graphene import NonNull, ObjectType, Schema, String, Union + + +class Author(ObjectType): + first_name = String() + last_name = String() + + +class Book(ObjectType): + id = Int() + name = String() + isbn = String() + author = Field(Author) + branch = String() + + +class Magazine(ObjectType): + id = Int() + name = String() + issue = Int() + branch = String() + + +class Item(Union): + class Meta: + types = (Book, Magazine) + + +class Library(ObjectType): + id = Int() + branch = String() + magazine = Field(List(Magazine)) + book = Field(List(Book)) + + +Storage = List(String) + + +authors = [ + Author( + first_name="New", + last_name="Relic", + ), + Author( + first_name="Bob", + last_name="Smith", + ), + Author( + first_name="Leslie", + last_name="Jones", + ), +] + +books = [ + Book( + id=1, + name="Python Agent: The Book", + isbn="a-fake-isbn", + author=authors[0], + branch="riverside", + ), + Book( + id=2, + name="Ollies for O11y: A Sk8er's Guide to Observability", + isbn="a-second-fake-isbn", + author=authors[1], + branch="downtown", + ), + Book( + id=3, + name="[Redacted]", + isbn="a-third-fake-isbn", + author=authors[2], + branch="riverside", + ), +] + +magazines = [ + Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), + Magazine(id=2, name="Reli Updates Weekly", issue=2, branch="downtown"), + Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), +] + + +libraries = ["riverside", "downtown"] +libraries = [ + Library( + id=i + 1, + branch=branch, + magazine=[m for m in magazines if m.branch == branch], + book=[b for b in books if b.branch == branch], + ) + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_library(self, info, index): + return libraries[index] + +def resolve_storage(self, info): + return [storage.pop()] + +def resolve_search(self, info, contains): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + +def resolve_hello(self, info): + return "Hello!" + +def resolve_echo(self, info, echo): + return echo + +def resolve_error(self, info): + raise RuntimeError("Runtime Error!") + +def resolve_storage_add(self, info, string): + storage.append(string) + return StorageAdd(string=string) + + +class StorageAdd(GrapheneMutation): + class Arguments: + string = String(required=True) + + string = String() + mutate = resolve_storage_add + + +class Query(ObjectType): + library = Field(Library, index=Int(required=True), resolver=resolve_library) + hello = String(resolver=resolve_hello) + search = Field(List(Item), contains=String(required=True), resolver=resolve_search) + echo = Field(String, echo=String(required=True), resolver=resolve_echo) + storage = Field(Storage, resolver=resolve_storage) + error = String(resolver=resolve_error) + error_non_null = Field(NonNull(String), resolver=resolve_error) + error_middleware = String(resolver=resolve_hello) + + +class Mutation(ObjectType): + storage_add = StorageAdd.Field() + + +target_schema = Schema(query=Query, mutation=Mutation, auto_camelcase=False) diff --git a/tests/framework_graphene/test_application.py b/tests/framework_graphene/test_application.py index fd02d992a..838f3b515 100644 --- a/tests/framework_graphene/test_application.py +++ b/tests/framework_graphene/test_application.py @@ -13,518 +13,25 @@ # limitations under the License. import pytest -import six -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events -from testing_support.validators.validate_transaction_count import ( - validate_transaction_count, -) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) -from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name +from framework_graphql.test_application import * +from newrelic.common.package_version_utils import get_package_version +GRAPHENE_VERSION = get_package_version("graphene") -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - major_version = int(version.split(".")[0]) - return major_version == 2 +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) +def target_application(request): + from ._target_application import target_application + target_application = target_application.get(request.param, None) + if target_application is None: + pytest.skip("Unsupported combination.") + return -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute(*args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Graphene/all", 1), - ("GraphQL/Graphene/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_graphene import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Graphene/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Graphene/storage", 1), - ("GraphQL/resolve/Graphene/storage_add", 1), - ("GraphQL/operation/Graphene/query//storage", 1), - ("GraphQL/operation/Graphene/mutation//storage_add.string", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Graphene/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Graphene/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "StorageAdd", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") { string } }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - _test() - - -@dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): - _test_middleware_metrics = [ - ("GraphQL/operation/Graphene/query//hello", 1), - ("GraphQL/resolve/Graphene/hello", 1), - ("Function/test_application:example_middleware", 1), - ] - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and 1 Resolver Function - @validate_span_events(count=5) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }", middleware=[example_middleware]) - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -@dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Graphene/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - "test_application:error_middleware", - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query, middleware=[error_middleware]) - assert response.errors - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - - if six.PY2: - txn_name = "_target_application:resolve_error" - else: - txn_name = "_target_application:Query.resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Graphene/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], -) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Graphene///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Graphene/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.errors - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Graphene/hello", 1)] - graphql_attrs = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - (to_graphql_source("{ hello }"), "{ hello }"), # Extract query from Source objects - ("{ library(index: 0) { branch } }", "{ library(index: ?) { branch } }"), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - if callable(query): - query = query() - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. (need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - if six.PY2: - txn_name = "_target_application:resolve_error" - else: - txn_name = "_target_application:Query.resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors - - _test() - - -@pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): - txn_ct = 1 if capture_introspection_setting else 0 - - @override_application_settings( - {"instrumentation.graphql.capture_introspection_queries": capture_introspection_setting} - ) - @validate_transaction_count(txn_ct) - @background_task() - def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors - - _test() + param = request.param.split("-") + is_background = param[0] not in {"wsgi", "asgi"} + schema_type = param[1] + extra_spans = 4 if param[0] == "wsgi" else 0 + assert GRAPHENE_VERSION is not None + return "Graphene", GRAPHENE_VERSION, target_application, is_background, schema_type, extra_spans diff --git a/tests/framework_graphql/__init__.py b/tests/framework_graphql/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_graphql/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_graphql/_target_application.py b/tests/framework_graphql/_target_application.py index 7bef5e975..91da5d767 100644 --- a/tests/framework_graphql/_target_application.py +++ b/tests/framework_graphql/_target_application.py @@ -12,228 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from graphql import ( - GraphQLArgument, - GraphQLField, - GraphQLInt, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLSchema, - GraphQLString, - GraphQLUnionType, -) - -authors = [ - { - "first_name": "New", - "last_name": "Relic", - }, - { - "first_name": "Bob", - "last_name": "Smith", - }, - { - "first_name": "Leslie", - "last_name": "Jones", - }, -] - -books = [ - { - "id": 1, - "name": "Python Agent: The Book", - "isbn": "a-fake-isbn", - "author": authors[0], - "branch": "riverside", - }, - { - "id": 2, - "name": "Ollies for O11y: A Sk8er's Guide to Observability", - "isbn": "a-second-fake-isbn", - "author": authors[1], - "branch": "downtown", - }, - { - "id": 3, - "name": "[Redacted]", - "isbn": "a-third-fake-isbn", - "author": authors[2], - "branch": "riverside", - }, -] - -magazines = [ - {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, - {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, - {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, -] - - -libraries = ["riverside", "downtown"] -libraries = [ - { - "id": i + 1, - "branch": branch, - "magazine": [m for m in magazines if m["branch"] == branch], - "book": [b for b in books if b["branch"] == branch], - } - for i, branch in enumerate(libraries) -] - -storage = [] - - -def resolve_library(parent, info, index): - return libraries[index] - - -def resolve_storage_add(parent, info, string): - storage.append(string) - return string - - -def resolve_storage(parent, info): - return storage - - -def resolve_search(parent, info, contains): - search_books = [b for b in books if contains in b["name"]] - search_magazines = [m for m in magazines if contains in m["name"]] - return search_books + search_magazines - - -Author = GraphQLObjectType( - "Author", - { - "first_name": GraphQLField(GraphQLString), - "last_name": GraphQLField(GraphQLString), - }, -) - -Book = GraphQLObjectType( - "Book", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "isbn": GraphQLField(GraphQLString), - "author": GraphQLField(Author), - "branch": GraphQLField(GraphQLString), - }, -) - -Magazine = GraphQLObjectType( - "Magazine", - { - "id": GraphQLField(GraphQLInt), - "name": GraphQLField(GraphQLString), - "issue": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - }, -) - - -Library = GraphQLObjectType( - "Library", - { - "id": GraphQLField(GraphQLInt), - "branch": GraphQLField(GraphQLString), - "book": GraphQLField(GraphQLList(Book)), - "magazine": GraphQLField(GraphQLList(Magazine)), - }, -) - -Storage = GraphQLList(GraphQLString) - - -def resolve_hello(root, info): - return "Hello!" - - -def resolve_echo(root, info, echo): - return echo - - -def resolve_error(root, info): - raise RuntimeError("Runtime Error!") - - -try: - hello_field = GraphQLField(GraphQLString, resolver=resolve_hello) - library_field = GraphQLField( - Library, - resolver=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolver=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolver=resolve_storage, - ) - storage_add_field = GraphQLField( - Storage, - resolver=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolver=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolver=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolver=resolve_hello) -except TypeError: - hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) - library_field = GraphQLField( - Library, - resolve=resolve_library, - args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, - ) - search_field = GraphQLField( - GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), - args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - echo_field = GraphQLField( - GraphQLString, - resolve=resolve_echo, - args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - storage_field = GraphQLField( - Storage, - resolve=resolve_storage, - ) - storage_add_field = GraphQLField( - GraphQLString, - resolve=resolve_storage_add, - args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, - ) - error_field = GraphQLField(GraphQLString, resolve=resolve_error) - error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) - error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) - -query = GraphQLObjectType( - name="Query", - fields={ - "hello": hello_field, - "library": library_field, - "search": search_field, - "echo": echo_field, - "storage": storage_field, - "error": error_field, - "error_non_null": error_non_null_field, - "error_middleware": error_middleware_field, - }, -) - -mutation = GraphQLObjectType( - name="Mutation", - fields={ - "storage_add": storage_add_field, - }, -) - -_target_application = GraphQLSchema(query=query, mutation=mutation) +from graphql.language.source import Source + +from ._target_schema_async import target_schema as target_schema_async +from ._target_schema_sync import target_schema as target_schema_sync + + +def check_response(query, response): + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors, response.errors + assert response.data + else: + assert response.errors + + +def run_sync(schema): + def _run_sync(query, middleware=None): + try: + from graphql import graphql_sync as graphql + except ImportError: + from graphql import graphql + + response = graphql(schema, query, middleware=middleware) + + check_response(query, response) + + return response.data + + return _run_sync + + +def run_async(schema): + import asyncio + + from graphql import graphql + + def _run_async(query, middleware=None): + coro = graphql(schema, query, middleware=middleware) + loop = asyncio.get_event_loop() + response = loop.run_until_complete(coro) + + check_response(query, response) + + return response.data + + return _run_async + + +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "async-async": run_async(target_schema_async), +} diff --git a/tests/framework_graphql/_target_schema_async.py b/tests/framework_graphql/_target_schema_async.py new file mode 100644 index 000000000..aad4eb271 --- /dev/null +++ b/tests/framework_graphql/_target_schema_async.py @@ -0,0 +1,155 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLUnionType, +) + +from ._target_schema_sync import books, libraries, magazines + +storage = [] + + +async def resolve_library(parent, info, index): + return libraries[index] + + +async def resolve_storage_add(parent, info, string): + storage.append(string) + return string + + +async def resolve_storage(parent, info): + return [storage.pop()] + + +async def resolve_search(parent, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +Author = GraphQLObjectType( + "Author", + { + "first_name": GraphQLField(GraphQLString), + "last_name": GraphQLField(GraphQLString), + }, +) + +Book = GraphQLObjectType( + "Book", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "isbn": GraphQLField(GraphQLString), + "author": GraphQLField(Author), + "branch": GraphQLField(GraphQLString), + }, +) + +Magazine = GraphQLObjectType( + "Magazine", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "issue": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + }, +) + + +Library = GraphQLObjectType( + "Library", + { + "id": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + "book": GraphQLField(GraphQLList(Book)), + "magazine": GraphQLField(GraphQLList(Magazine)), + }, +) + +Storage = GraphQLList(GraphQLString) + + +async def resolve_hello(root, info): + return "Hello!" + + +async def resolve_echo(root, info, echo): + return echo + + +async def resolve_error(root, info): + raise RuntimeError("Runtime Error!") + + +hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) +library_field = GraphQLField( + Library, + resolve=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolve=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolve=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolve=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolve=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) + +query = GraphQLObjectType( + name="Query", + fields={ + "hello": hello_field, + "library": library_field, + "search": search_field, + "echo": echo_field, + "storage": storage_field, + "error": error_field, + "error_non_null": error_non_null_field, + "error_middleware": error_middleware_field, + }, +) + +mutation = GraphQLObjectType( + name="Mutation", + fields={ + "storage_add": storage_add_field, + }, +) + +target_schema = GraphQLSchema(query=query, mutation=mutation) diff --git a/tests/framework_graphql/_target_schema_sync.py b/tests/framework_graphql/_target_schema_sync.py new file mode 100644 index 000000000..302a6c66e --- /dev/null +++ b/tests/framework_graphql/_target_schema_sync.py @@ -0,0 +1,210 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applic`ab`le law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLUnionType, +) + +authors = [ + { + "first_name": "New", + "last_name": "Relic", + }, + { + "first_name": "Bob", + "last_name": "Smith", + }, + { + "first_name": "Leslie", + "last_name": "Jones", + }, +] + +books = [ + { + "id": 1, + "name": "Python Agent: The Book", + "isbn": "a-fake-isbn", + "author": authors[0], + "branch": "riverside", + }, + { + "id": 2, + "name": "Ollies for O11y: A Sk8er's Guide to Observability", + "isbn": "a-second-fake-isbn", + "author": authors[1], + "branch": "downtown", + }, + { + "id": 3, + "name": "[Redacted]", + "isbn": "a-third-fake-isbn", + "author": authors[2], + "branch": "riverside", + }, +] + +magazines = [ + {"id": 1, "name": "Reli Updates Weekly", "issue": 1, "branch": "riverside"}, + {"id": 2, "name": "Reli Updates Weekly", "issue": 2, "branch": "downtown"}, + {"id": 3, "name": "Node Weekly", "issue": 1, "branch": "riverside"}, +] + + +libraries = ["riverside", "downtown"] +libraries = [ + { + "id": i + 1, + "branch": branch, + "magazine": [m for m in magazines if m["branch"] == branch], + "book": [b for b in books if b["branch"] == branch], + } + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_library(parent, info, index): + return libraries[index] + + +def resolve_storage_add(parent, info, string): + storage.append(string) + return string + + +def resolve_storage(parent, info): + return [storage.pop()] + + +def resolve_search(parent, info, contains): + search_books = [b for b in books if contains in b["name"]] + search_magazines = [m for m in magazines if contains in m["name"]] + return search_books + search_magazines + + +Author = GraphQLObjectType( + "Author", + { + "first_name": GraphQLField(GraphQLString), + "last_name": GraphQLField(GraphQLString), + }, +) + +Book = GraphQLObjectType( + "Book", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "isbn": GraphQLField(GraphQLString), + "author": GraphQLField(Author), + "branch": GraphQLField(GraphQLString), + }, +) + +Magazine = GraphQLObjectType( + "Magazine", + { + "id": GraphQLField(GraphQLInt), + "name": GraphQLField(GraphQLString), + "issue": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + }, +) + + +Library = GraphQLObjectType( + "Library", + { + "id": GraphQLField(GraphQLInt), + "branch": GraphQLField(GraphQLString), + "book": GraphQLField(GraphQLList(Book)), + "magazine": GraphQLField(GraphQLList(Magazine)), + }, +) + +Storage = GraphQLList(GraphQLString) + + +def resolve_hello(root, info): + return "Hello!" + + +def resolve_echo(root, info, echo): + return echo + + +def resolve_error(root, info): + raise RuntimeError("Runtime Error!") + + +hello_field = GraphQLField(GraphQLString, resolve=resolve_hello) +library_field = GraphQLField( + Library, + resolve=resolve_library, + args={"index": GraphQLArgument(GraphQLNonNull(GraphQLInt))}, +) +search_field = GraphQLField( + GraphQLList(GraphQLUnionType("Item", (Book, Magazine), resolve_type=resolve_search)), + args={"contains": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +echo_field = GraphQLField( + GraphQLString, + resolve=resolve_echo, + args={"echo": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +storage_field = GraphQLField( + Storage, + resolve=resolve_storage, +) +storage_add_field = GraphQLField( + GraphQLString, + resolve=resolve_storage_add, + args={"string": GraphQLArgument(GraphQLNonNull(GraphQLString))}, +) +error_field = GraphQLField(GraphQLString, resolve=resolve_error) +error_non_null_field = GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_error) +error_middleware_field = GraphQLField(GraphQLString, resolve=resolve_hello) + +query = GraphQLObjectType( + name="Query", + fields={ + "hello": hello_field, + "library": library_field, + "search": search_field, + "echo": echo_field, + "storage": storage_field, + "error": error_field, + "error_non_null": error_non_null_field, + "error_middleware": error_middleware_field, + }, +) + +mutation = GraphQLObjectType( + name="Mutation", + fields={ + "storage_add": storage_add_field, + }, +) + +target_schema = GraphQLSchema(query=query, mutation=mutation) diff --git a/tests/framework_graphql/conftest.py b/tests/framework_graphql/conftest.py index 4d9e06758..5302da2b8 100644 --- a/tests/framework_graphql/conftest.py +++ b/tests/framework_graphql/conftest.py @@ -13,10 +13,12 @@ # limitations under the License. import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) +from newrelic.packages import six _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -32,11 +34,16 @@ ) -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async"]) +def target_application(request): + from ._target_application import target_application + + app = target_application.get(request.param, None) + if app is None: + pytest.skip("Unsupported combination.") + return - return _target_application + return "GraphQL", None, app, True, request.param.split("-")[1], 0 if six.PY2: diff --git a/tests/framework_graphql/test_application.py b/tests/framework_graphql/test_application.py index dd49ee37f..b5d78699d 100644 --- a/tests/framework_graphql/test_application.py +++ b/tests/framework_graphql/test_application.py @@ -13,6 +13,10 @@ # limitations under the License. import pytest +from framework_graphql.test_application_async import ( + error_middleware_async, + example_middleware_async, +) from testing_support.fixtures import dt_enabled, override_application_settings from testing_support.validators.validate_code_level_metrics import ( validate_code_level_metrics, @@ -30,24 +34,18 @@ from newrelic.api.background_task import background_task from newrelic.common.object_names import callable_name +from newrelic.common.package_version_utils import get_package_version -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 +graphql_version = get_package_version("graphql-core") +def conditional_decorator(decorator, condition): + def _conditional_decorator(func): + if not condition: + return func + return decorator(func) -@pytest.fixture(scope="session") -def graphql_run(): - try: - from graphql import graphql_sync as graphql - except ImportError: - from graphql import graphql - - return graphql + return _conditional_decorator def to_graphql_source(query): @@ -58,13 +56,6 @@ def delay_import(): # Fallback if Source is not implemented return query - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - return Source(query) return delay_import @@ -79,66 +70,86 @@ def error_middleware(next, root, info, **args): raise RuntimeError("Runtime Error!") -def test_no_harm_no_transaction(app, graphql_run): +def test_no_harm_no_transaction(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() +example_middleware = [example_middleware] +error_middleware = [error_middleware] + +example_middleware.append(example_middleware_async) +error_middleware.append(error_middleware_async) + _runtime_error_name = callable_name(RuntimeError) _test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/GraphQL/all", 1), - ("GraphQL/GraphQL/allOther", 1), -] -def test_basic(app, graphql_run): - from graphql import __version__ as version +def _graphql_base_rollup_metrics(framework, version, background_task=True): + graphql_version = get_package_version("graphql-core") - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), + metrics = [ + ("Python/Framework/GraphQL/%s" % graphql_version, 1), + ("GraphQL/all", 1), + ("GraphQL/%s/all" % framework, 1), ] + if background_task: + metrics.extend( + [ + ("GraphQL/allOther", 1), + ("GraphQL/%s/allOther" % framework, 1), + ] + ) + else: + metrics.extend( + [ + ("GraphQL/allWeb", 1), + ("GraphQL/%s/allWeb" % framework, 1), + ] + ) + + if framework != "GraphQL": + metrics.append(("Python/Framework/%s/%s" % (framework, version), 1)) + + return metrics + + +def test_basic(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application @validate_transaction_metrics( "query//hello", "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, + rollup_metrics=_graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors + response = target_application("{ hello }") + assert response["hello"] == "Hello!" _test() @dt_enabled -def test_query_and_mutation(app, graphql_run, is_graphql_2): - from graphql import __version__ as version +def test_query_and_mutation(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + + mutation_path = "storage_add" if framework != "Graphene" else "storage_add.string" + type_annotation = "!" if framework == "Strawberry" else "" - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/GraphQL/storage", 1), - ("GraphQL/resolve/GraphQL/storage_add", 1), - ("GraphQL/operation/GraphQL/query//storage", 1), - ("GraphQL/operation/GraphQL/mutation//storage_add", 1), + ("GraphQL/resolve/%s/storage_add" % framework, 1), + ("GraphQL/operation/%s/mutation//%s" % (framework, mutation_path), 1), + ] + _test_query_scoped_metrics = [ + ("GraphQL/resolve/%s/storage" % framework, 1), + ("GraphQL/operation/%s/query//storage" % framework, 1), ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/GraphQL/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/GraphQL/allOther", 2), - ] + _test_mutation_scoped_metrics _expected_mutation_operation_attributes = { "graphql.operation.type": "mutation", @@ -148,7 +159,7 @@ def test_query_and_mutation(app, graphql_run, is_graphql_2): "graphql.field.name": "storage_add", "graphql.field.parentType": "Mutation", "graphql.field.path": "storage_add", - "graphql.field.returnType": "[String]" if is_graphql_2 else "String", + "graphql.field.returnType": ("String" if framework != "Graphene" else "StorageAdd") + type_annotation, } _expected_query_operation_attributes = { "graphql.operation.type": "query", @@ -158,78 +169,108 @@ def test_query_and_mutation(app, graphql_run, is_graphql_2): "graphql.field.name": "storage", "graphql.field.parentType": "Query", "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", + "graphql.field.returnType": "[String%s]%s" % (type_annotation, type_annotation), } - @validate_code_level_metrics("_target_application", "resolve_storage") - @validate_code_level_metrics("_target_application", "resolve_storage_add") + @validate_code_level_metrics( + "framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage_add" + ) + @validate_span_events(exact_agents=_expected_mutation_operation_attributes) + @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) @validate_transaction_metrics( - "query//storage", + "mutation//%s" % mutation_path, "GraphQL", scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, + rollup_metrics=_test_mutation_scoped_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) + @conditional_decorator(background_task(), is_bg) + def _mutation(): + if framework == "Graphene": + query = 'mutation { storage_add(string: "abc") { string } }' + else: + query = 'mutation { storage_add(string: "abc") }' + response = target_application(query) + assert response["storage_add"] == "abc" or response["storage_add"]["string"] == "abc" + + @validate_code_level_metrics("framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_storage") @validate_span_events(exact_agents=_expected_query_operation_attributes) @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) + @validate_transaction_metrics( + "query//storage", + "GraphQL", + scoped_metrics=_test_query_scoped_metrics, + rollup_metrics=_test_query_scoped_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, + ) + @conditional_decorator(background_task(), is_bg) + def _query(): + response = target_application("query { storage }") + assert response["storage"] == ["abc"] - _test() + _mutation() + _query() +@pytest.mark.parametrize("middleware", example_middleware) @dt_enabled -def test_middleware(app, graphql_run, is_graphql_2): +def test_middleware(target_application, middleware): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + + name = "%s:%s" % (middleware.__module__, middleware.__name__) + if "async" in name: + if schema_type != "async": + pytest.skip("Async middleware not supported in sync applications.") + _test_middleware_metrics = [ - ("GraphQL/operation/GraphQL/query//hello", 1), - ("GraphQL/resolve/GraphQL/hello", 1), - ("Function/test_application:example_middleware", 1), + ("GraphQL/operation/%s/query//hello" % framework, 1), + ("GraphQL/resolve/%s/hello" % framework, 1), + ("Function/%s" % name, 1), ] - @validate_code_level_metrics("test_application", "example_middleware") - @validate_code_level_metrics("_target_application", "resolve_hello") + # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver Function + span_count = 5 + extra_spans + + @validate_code_level_metrics(*name.split(":")) + @validate_code_level_metrics("framework_%s._target_schema_%s" % (framework.lower(), schema_type), "resolve_hello") + @validate_span_events(count=span_count) @validate_transaction_metrics( "query//hello", "GraphQL", scoped_metrics=_test_middleware_metrics, - rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_middleware_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 5: Transaction, Operation, Middleware, and 1 Resolver and Resolver Function - @validate_span_events(count=5) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }", middleware=[example_middleware]) - assert not response.errors - assert "Hello!" in str(response.data) + response = target_application("{ hello }", middleware=[middleware]) + assert response["hello"] == "Hello!" _test() +@pytest.mark.parametrize("middleware", error_middleware) @dt_enabled -def test_exception_in_middleware(app, graphql_run): - query = "query MyQuery { hello }" - field = "hello" +def test_exception_in_middleware(target_application, middleware): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + query = "query MyQuery { error_middleware }" + field = "error_middleware" + + name = "%s:%s" % (middleware.__module__, middleware.__name__) + if "async" in name: + if schema_type != "async": + pytest.skip("Async middleware not supported in sync applications.") # Metrics _test_exception_scoped_metrics = [ - ("GraphQL/operation/GraphQL/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/GraphQL/%s" % field, 1), + ("GraphQL/operation/%s/query/MyQuery/%s" % (framework, field), 1), + ("GraphQL/resolve/%s/%s" % (framework, field), 1), + ("Function/%s" % name, 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/test_application:error_middleware", 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -246,39 +287,39 @@ def test_exception_in_middleware(app, graphql_run): } @validate_transaction_metrics( - "test_application:error_middleware", + name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_span_events(exact_agents=_expected_exception_resolver_attributes) @validate_transaction_errors(errors=_test_runtime_error) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query, middleware=[error_middleware]) - assert response.errors + response = target_application(query, middleware=[middleware]) _test() @pytest.mark.parametrize("field", ("error", "error_non_null")) @dt_enabled -def test_exception_in_resolver(app, graphql_run, field): +def test_exception_in_resolver(target_application, field): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application query = "query MyQuery { %s }" % field - txn_name = "_target_application:resolve_error" + txn_name = "framework_%s._target_schema_%s:resolve_error" % (framework.lower(), schema_type) # Metrics _test_exception_scoped_metrics = [ - ("GraphQL/operation/GraphQL/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/GraphQL/%s" % field, 1), + ("GraphQL/operation/%s/query/MyQuery/%s" % (framework, field), 1), + ("GraphQL/resolve/%s/%s" % (framework, field), 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", txn_name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -298,16 +339,15 @@ def test_exception_in_resolver(app, graphql_run, field): txn_name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_span_events(exact_agents=_expected_exception_resolver_attributes) @validate_transaction_errors(errors=_test_runtime_error) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - assert response.errors + response = target_application(query) _test() @@ -316,18 +356,16 @@ def _test(): @pytest.mark.parametrize( "query,exc_class", [ - ("query MyQuery { missing_field }", "GraphQLError"), + ("query MyQuery { error_missing_field }", "GraphQLError"), ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), ], ) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): +def test_exception_in_validation(target_application, query, exc_class): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application if "syntax" in query: txn_name = "graphql.language.parser:parse" else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" + txn_name = "graphql.validation.validate:validate" # Import path differs between versions if exc_class == "GraphQLError": @@ -336,12 +374,12 @@ def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_clas exc_class = callable_name(GraphQLError) _test_exception_scoped_metrics = [ - # ('GraphQL/operation/GraphQL///', 1), + ("GraphQL/operation/%s///" % framework, 1), ] _test_exception_rollup_metrics = [ ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), + ("Errors/all%s" % ("Other" if is_bg else "Web"), 1), + ("Errors/%sTransaction/GraphQL/%s" % ("Other" if is_bg else "Web", txn_name), 1), ] + _test_exception_scoped_metrics # Attributes @@ -355,72 +393,77 @@ def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_clas txn_name, "GraphQL", scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) @validate_span_events(exact_agents=_expected_exception_operation_attributes) @validate_transaction_errors(errors=[exc_class]) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - assert response.errors + response = target_application(query) _test() @dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/GraphQL/query/MyQuery/library", 1)] +def test_operation_metrics_and_attrs(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + operation_metrics = [("GraphQL/operation/%s/query/MyQuery/library" % framework, 1)] operation_attrs = { "graphql.operation.type": "query", "graphql.operation.name": "MyQuery", } + # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions + # library, library.name, library.book + # library.book.name and library.book.id for each book resolved (in this case 2) + span_count = 16 + extra_spans # WSGI may add 4 spans, other frameworks may add other amounts + @validate_transaction_metrics( "query/MyQuery/library", "GraphQL", scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=operation_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) + @validate_span_events(count=span_count) @validate_span_events(exact_agents=operation_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.errors + response = target_application("query MyQuery { library(index: 0) { branch, book { id, name } } }") _test() @dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/GraphQL/hello", 1)] +def test_field_resolver_metrics_and_attrs(target_application): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + field_resolver_metrics = [("GraphQL/resolve/%s/hello" % framework, 1)] + + type_annotation = "!" if framework == "Strawberry" else "" graphql_attrs = { "graphql.field.name": "hello", "graphql.field.parentType": "Query", "graphql.field.path": "hello", - "graphql.field.returnType": "String", + "graphql.field.returnType": "String" + type_annotation, } + # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function + span_count = 4 + extra_spans # WSGI may add 4 spans, other frameworks may add other amounts + @validate_transaction_metrics( "query//hello", "GraphQL", scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, + rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics(framework, version, is_bg), + background_task=is_bg, ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) + @validate_span_events(count=span_count) @validate_span_events(exact_agents=graphql_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) + response = target_application("{ hello }") + assert response["hello"] == "Hello!" _test() @@ -443,18 +486,19 @@ def _test(): @dt_enabled @pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): +def test_query_obfuscation(target_application, query, obfuscated): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application graphql_attrs = {"graphql.operation.query": obfuscated} if callable(query): + if framework != "GraphQL": + pytest.skip("Source query objects not tested outside of graphql-core") query = query() @validate_span_events(exact_agents=graphql_attrs) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors + response = target_application(query) _test() @@ -499,28 +543,28 @@ def _test(): @dt_enabled @pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): +def test_deepest_unique_path(target_application, query, expected_path): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application if expected_path == "/error": - txn_name = "_target_application:resolve_error" + txn_name = "framework_%s._target_schema_%s:resolve_error" % (framework.lower(), schema_type) else: txn_name = "query/%s" % expected_path @validate_transaction_metrics( txn_name, "GraphQL", - background_task=True, + background_task=is_bg, ) - @background_task() + @conditional_decorator(background_task(), is_bg) def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors + response = target_application(query) _test() @pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): +def test_introspection_transactions(target_application, capture_introspection_setting): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application txn_ct = 1 if capture_introspection_setting else 0 @override_application_settings( @@ -529,7 +573,6 @@ def test_introspection_transactions(app, graphql_run, capture_introspection_sett @validate_transaction_count(txn_ct) @background_task() def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() diff --git a/tests/framework_graphql/test_application_async.py b/tests/framework_graphql/test_application_async.py index 28b435c43..39c1871ef 100644 --- a/tests/framework_graphql/test_application_async.py +++ b/tests/framework_graphql/test_application_async.py @@ -12,99 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio +from inspect import isawaitable -import pytest -from test_application import is_graphql_2 -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events -from newrelic.api.background_task import background_task +# Async Functions not allowed in Py2 +async def example_middleware_async(next, root, info, **args): + return_value = next(root, info, **args) + if isawaitable(return_value): + return await return_value + return return_value -@pytest.fixture(scope="session") -def graphql_run_async(): - from graphql import __version__ as version - from graphql import graphql - - major_version = int(version.split(".")[0]) - if major_version == 2: - - def graphql_run(*args, **kwargs): - return graphql(*args, return_promise=True, **kwargs) - - return graphql_run - else: - return graphql - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async, is_graphql_2): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/GraphQL/storage", 1), - ("GraphQL/resolve/GraphQL/storage_add", 1), - ("GraphQL/operation/GraphQL/query//storage", 1), - ("GraphQL/operation/GraphQL/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/GraphQL/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/GraphQL/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "[String]" if is_graphql_2 else "String", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String]", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = await graphql_run_async(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - loop = asyncio.new_event_loop() - loop.run_until_complete(coro()) - - _test() +async def error_middleware_async(next, root, info, **args): + raise RuntimeError("Runtime Error!") diff --git a/tests/framework_starlette/test_graphql.py b/tests/framework_starlette/test_graphql.py deleted file mode 100644 index 24ec3ab38..000000000 --- a/tests/framework_starlette/test_graphql.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -def get_starlette_version(): - import starlette - - version = getattr(starlette, "__version__", "0.0.0").split(".") - return tuple(int(x) for x in version) - - -@pytest.fixture(scope="session") -def target_application(): - import _test_graphql - - return _test_graphql.target_application - - -@dt_enabled -@pytest.mark.parametrize("endpoint", ("/async", "/sync")) -@pytest.mark.skipif(get_starlette_version() >= (0, 17), reason="Starlette GraphQL support dropped in v0.17.0") -def test_graphql_metrics_and_attrs(target_application, endpoint): - from graphql import __version__ as version - - from newrelic.hooks.framework_graphene import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Graphene/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_scoped_metrics = [ - ("GraphQL/resolve/Graphene/hello", 1), - ("GraphQL/operation/Graphene/query//hello", 1), - ] - _test_unscoped_metrics = [ - ("GraphQL/all", 1), - ("GraphQL/Graphene/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Graphene/allWeb", 1), - ] + _test_scoped_metrics - - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - "graphql.operation.query": "{ hello }", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_scoped_metrics, - rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS, - ) - def _test(): - response = target_application.make_request( - "POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"} - ) - assert response.status == 200 - assert "Hello!" in response.body.decode("utf-8") - - _test() diff --git a/tests/framework_strawberry/__init__.py b/tests/framework_strawberry/__init__.py new file mode 100644 index 000000000..8030baccf --- /dev/null +++ b/tests/framework_strawberry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/framework_strawberry/_target_application.py b/tests/framework_strawberry/_target_application.py index e032fc27a..afba04873 100644 --- a/tests/framework_strawberry/_target_application.py +++ b/tests/framework_strawberry/_target_application.py @@ -12,185 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union - -import strawberry.mutation -import strawberry.type -from strawberry import Schema, field -from strawberry.asgi import GraphQL -from strawberry.schema.config import StrawberryConfig -from strawberry.types.types import Optional - - -@strawberry.type -class Author: - first_name: str - last_name: str - - -@strawberry.type -class Book: - id: int - name: str - isbn: str - author: Author - branch: str - - -@strawberry.type -class Magazine: - id: int - name: str - issue: int - branch: str - -@strawberry.type -class Library: - id: int - branch: str - magazine: List[Magazine] - book: List[Book] +import asyncio +import json +import pytest +from framework_strawberry._target_schema_async import ( + target_asgi_application as target_asgi_application_async, +) +from framework_strawberry._target_schema_async import ( + target_schema as target_schema_async, +) +from framework_strawberry._target_schema_sync import ( + target_asgi_application as target_asgi_application_sync, +) +from framework_strawberry._target_schema_sync import target_schema as target_schema_sync -Item = Union[Book, Magazine] -Storage = List[str] +def run_sync(schema): + def _run_sync(query, middleware=None): + from graphql.language.source import Source -authors = [ - Author( - first_name="New", - last_name="Relic", - ), - Author( - first_name="Bob", - last_name="Smith", - ), - Author( - first_name="Leslie", - last_name="Jones", - ), -] - -books = [ - Book( - id=1, - name="Python Agent: The Book", - isbn="a-fake-isbn", - author=authors[0], - branch="riverside", - ), - Book( - id=2, - name="Ollies for O11y: A Sk8er's Guide to Observability", - isbn="a-second-fake-isbn", - author=authors[1], - branch="downtown", - ), - Book( - id=3, - name="[Redacted]", - isbn="a-third-fake-isbn", - author=authors[2], - branch="riverside", - ), -] + if middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -magazines = [ - Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), - Magazine(id=2, name="Reli: The Forgotten Years", issue=2, branch="downtown"), - Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), -] + response = schema.execute_sync(query) + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors + else: + assert response.errors -libraries = ["riverside", "downtown"] -libraries = [ - Library( - id=i + 1, - branch=branch, - magazine=[m for m in magazines if m.branch == branch], - book=[b for b in books if b.branch == branch], - ) - for i, branch in enumerate(libraries) -] + return response.data -storage = [] + return _run_sync -def resolve_hello(): - return "Hello!" +def run_async(schema): + def _run_async(query, middleware=None): + from graphql.language.source import Source + if middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -async def resolve_hello_async(): - return "Hello!" + loop = asyncio.get_event_loop() + response = loop.run_until_complete(schema.execute(query)) + if isinstance(query, str) and "error" not in query or isinstance(query, Source) and "error" not in query.body: + assert not response.errors + else: + assert response.errors -def resolve_echo(echo: str): - return echo + return response.data + return _run_async -def resolve_library(index: int): - return libraries[index] +def run_asgi(app): + def _run_asgi(query, middleware=None): + if middleware is not None: + pytest.skip("Middleware not supported in Strawberry.") -def resolve_storage_add(string: str): - storage.add(string) - return storage + response = app.make_request( + "POST", "/", body=json.dumps({"query": query}), headers={"Content-Type": "application/json"} + ) + body = json.loads(response.body.decode("utf-8")) + if not isinstance(query, str) or "error" in query: + try: + assert response.status != 200 + except AssertionError: + assert body["errors"] + else: + assert response.status == 200 + assert "errors" not in body or not body["errors"] -def resolve_storage(): - return storage + return body["data"] + return _run_asgi -def resolve_error(): - raise RuntimeError("Runtime Error!") - -def resolve_search(contains: str): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - - -@strawberry.type -class Query: - library: Library = field(resolver=resolve_library) - hello: str = field(resolver=resolve_hello) - hello_async: str = field(resolver=resolve_hello_async) - search: List[Item] = field(resolver=resolve_search) - echo: str = field(resolver=resolve_echo) - storage: Storage = field(resolver=resolve_storage) - error: Optional[str] = field(resolver=resolve_error) - error_non_null: str = field(resolver=resolve_error) - - def resolve_library(self, info, index): - return libraries[index] - - def resolve_storage(self, info): - return storage - - def resolve_search(self, info, contains): - search_books = [b for b in books if contains in b.name] - search_magazines = [m for m in magazines if contains in m.name] - return search_books + search_magazines - - def resolve_hello(self, info): - return "Hello!" - - def resolve_echo(self, info, echo): - return echo - - def resolve_error(self, info) -> str: - raise RuntimeError("Runtime Error!") - - -@strawberry.type -class Mutation: - @strawberry.mutation - def storage_add(self, string: str) -> str: - storage.append(string) - return str(string) - - -_target_application = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) -_target_asgi_application = GraphQL(_target_application) +target_application = { + "sync-sync": run_sync(target_schema_sync), + "async-sync": run_async(target_schema_sync), + "asgi-sync": run_asgi(target_asgi_application_sync), + "async-async": run_async(target_schema_async), + "asgi-async": run_asgi(target_asgi_application_async), +} diff --git a/tests/framework_strawberry/_target_schema_async.py b/tests/framework_strawberry/_target_schema_async.py new file mode 100644 index 000000000..373cef537 --- /dev/null +++ b/tests/framework_strawberry/_target_schema_async.py @@ -0,0 +1,84 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import strawberry.mutation +import strawberry.type +from framework_strawberry._target_schema_sync import ( + Item, + Library, + Storage, + books, + libraries, + magazines, +) +from strawberry import Schema, field +from strawberry.asgi import GraphQL +from strawberry.schema.config import StrawberryConfig +from strawberry.types.types import Optional +from testing_support.asgi_testing import AsgiTest + +storage = [] + + +async def resolve_hello(): + return "Hello!" + + +async def resolve_echo(echo: str): + return echo + + +async def resolve_library(index: int): + return libraries[index] + + +async def resolve_storage_add(string: str): + storage.append(string) + return string + + +async def resolve_storage(): + return [storage.pop()] + + +async def resolve_error(): + raise RuntimeError("Runtime Error!") + + +async def resolve_search(contains: str): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + + +@strawberry.type +class Query: + library: Library = field(resolver=resolve_library) + hello: str = field(resolver=resolve_hello) + search: List[Item] = field(resolver=resolve_search) + echo: str = field(resolver=resolve_echo) + storage: Storage = field(resolver=resolve_storage) + error: Optional[str] = field(resolver=resolve_error) + error_non_null: str = field(resolver=resolve_error) + + +@strawberry.type +class Mutation: + storage_add: str = strawberry.mutation(resolver=resolve_storage_add) + + +target_schema = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) +target_asgi_application = AsgiTest(GraphQL(target_schema)) diff --git a/tests/framework_strawberry/_target_schema_sync.py b/tests/framework_strawberry/_target_schema_sync.py new file mode 100644 index 000000000..34bff75b9 --- /dev/null +++ b/tests/framework_strawberry/_target_schema_sync.py @@ -0,0 +1,169 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import strawberry.mutation +import strawberry.type +from strawberry import Schema, field +from strawberry.asgi import GraphQL +from strawberry.schema.config import StrawberryConfig +from strawberry.types.types import Optional +from testing_support.asgi_testing import AsgiTest + + +@strawberry.type +class Author: + first_name: str + last_name: str + + +@strawberry.type +class Book: + id: int + name: str + isbn: str + author: Author + branch: str + + +@strawberry.type +class Magazine: + id: int + name: str + issue: int + branch: str + + +@strawberry.type +class Library: + id: int + branch: str + magazine: List[Magazine] + book: List[Book] + + +Item = Union[Book, Magazine] +Storage = List[str] + + +authors = [ + Author( + first_name="New", + last_name="Relic", + ), + Author( + first_name="Bob", + last_name="Smith", + ), + Author( + first_name="Leslie", + last_name="Jones", + ), +] + +books = [ + Book( + id=1, + name="Python Agent: The Book", + isbn="a-fake-isbn", + author=authors[0], + branch="riverside", + ), + Book( + id=2, + name="Ollies for O11y: A Sk8er's Guide to Observability", + isbn="a-second-fake-isbn", + author=authors[1], + branch="downtown", + ), + Book( + id=3, + name="[Redacted]", + isbn="a-third-fake-isbn", + author=authors[2], + branch="riverside", + ), +] + +magazines = [ + Magazine(id=1, name="Reli Updates Weekly", issue=1, branch="riverside"), + Magazine(id=2, name="Reli: The Forgotten Years", issue=2, branch="downtown"), + Magazine(id=3, name="Node Weekly", issue=1, branch="riverside"), +] + + +libraries = ["riverside", "downtown"] +libraries = [ + Library( + id=i + 1, + branch=branch, + magazine=[m for m in magazines if m.branch == branch], + book=[b for b in books if b.branch == branch], + ) + for i, branch in enumerate(libraries) +] + +storage = [] + + +def resolve_hello(): + return "Hello!" + + +def resolve_echo(echo: str): + return echo + + +def resolve_library(index: int): + return libraries[index] + + +def resolve_storage_add(string: str): + storage.append(string) + return string + + +def resolve_storage(): + return [storage.pop()] + + +def resolve_error(): + raise RuntimeError("Runtime Error!") + + +def resolve_search(contains: str): + search_books = [b for b in books if contains in b.name] + search_magazines = [m for m in magazines if contains in m.name] + return search_books + search_magazines + + +@strawberry.type +class Query: + library: Library = field(resolver=resolve_library) + hello: str = field(resolver=resolve_hello) + search: List[Item] = field(resolver=resolve_search) + echo: str = field(resolver=resolve_echo) + storage: Storage = field(resolver=resolve_storage) + error: Optional[str] = field(resolver=resolve_error) + error_non_null: str = field(resolver=resolve_error) + + +@strawberry.type +class Mutation: + storage_add: str = strawberry.mutation(resolver=resolve_storage_add) + + +target_schema = Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False)) +target_asgi_application = AsgiTest(GraphQL(target_schema)) diff --git a/tests/framework_strawberry/conftest.py b/tests/framework_strawberry/conftest.py index 130866bcb..6345b3033 100644 --- a/tests/framework_strawberry/conftest.py +++ b/tests/framework_strawberry/conftest.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import six - -from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture # noqa: F401; pylint: disable=W0611 - +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) _default_settings = { "transaction_tracer.explain_threshold": 0.0, @@ -30,14 +29,3 @@ app_name="Python Agent Test (framework_strawberry)", default_settings=_default_settings, ) - - -@pytest.fixture(scope="session") -def app(): - from _target_application import _target_application - - return _target_application - - -if six.PY2: - collect_ignore = ["test_application_async.py"] diff --git a/tests/framework_strawberry/test_application.py b/tests/framework_strawberry/test_application.py index ac60a33e0..5a3f579ba 100644 --- a/tests/framework_strawberry/test_application.py +++ b/tests/framework_strawberry/test_application.py @@ -11,437 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest -from testing_support.fixtures import dt_enabled, override_application_settings -from testing_support.validators.validate_span_events import validate_span_events +from framework_graphql.test_application import * +from testing_support.fixtures import override_application_settings from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) -from testing_support.validators.validate_transaction_errors import ( - validate_transaction_errors, -) -from testing_support.validators.validate_transaction_metrics import ( - validate_transaction_metrics, -) from newrelic.api.background_task import background_task -from newrelic.common.object_names import callable_name - - -@pytest.fixture(scope="session") -def is_graphql_2(): - from graphql import __version__ as version - - major_version = int(version.split(".")[0]) - return major_version == 2 - - -@pytest.fixture(scope="session") -def graphql_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute_sync(*args, **kwargs) - - return execute - - -def to_graphql_source(query): - def delay_import(): - try: - from graphql import Source - except ImportError: - # Fallback if Source is not implemented - return query - - from graphql import __version__ as version - - # For graphql2, Source objects aren't acceptable input - major_version = int(version.split(".")[0]) - if major_version == 2: - return query - - return Source(query) - - return delay_import - - -def example_middleware(next, root, info, **args): # pylint: disable=W0622 - return_value = next(root, info, **args) - return return_value - - -def error_middleware(next, root, info, **args): # pylint: disable=W0622 - raise RuntimeError("Runtime Error!") - - -_runtime_error_name = callable_name(RuntimeError) -_test_runtime_error = [(_runtime_error_name, "Runtime Error!")] -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/Strawberry/allOther", 1), -] - - -def test_basic(app, graphql_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details +from newrelic.common.package_version_utils import get_package_version - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] +STRAWBERRY_VERSION = get_package_version("strawberry-graphql") - @validate_transaction_metrics( - "query//hello", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - - _test() - - -@dt_enabled -def test_query_and_mutation(app, graphql_run): - from graphql import __version__ as version - from newrelic.hooks.framework_strawberry import framework_details +@pytest.fixture(scope="session", params=["sync-sync", "async-sync", "async-async", "asgi-sync", "asgi-async"]) +def target_application(request): + from ._target_application import target_application - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Strawberry/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Strawberry/allOther", 2), - ] + _test_mutation_scoped_metrics + target_application = target_application[request.param] - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } + is_asgi = "asgi" in request.param + schema_type = request.param.split("-")[1] - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - response = graphql_run(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = graphql_run(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - _test() - - -@pytest.mark.parametrize("field", ("error", "error_non_null")) -@dt_enabled -def test_exception_in_resolver(app, graphql_run, field): - query = "query MyQuery { %s }" % field - - txn_name = "_target_application:resolve_error" - - # Metrics - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Strawberry/query/MyQuery/%s" % field, 1), - ("GraphQL/resolve/Strawberry/%s" % field, 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_resolver_attributes = { - "graphql.field.name": field, - "graphql.field.parentType": "Query", - "graphql.field.path": field, - "graphql.field.returnType": "String!" if "non_null" in field else "String", - } - _expected_exception_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_span_events(exact_agents=_expected_exception_resolver_attributes) - @validate_transaction_errors(errors=_test_runtime_error) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -@pytest.mark.parametrize( - "query,exc_class", - [ - ("query MyQuery { missing_field }", "GraphQLError"), - ("{ syntax_error ", "graphql.error.syntax_error:GraphQLSyntaxError"), - ], -) -def test_exception_in_validation(app, graphql_run, is_graphql_2, query, exc_class): - if "syntax" in query: - txn_name = "graphql.language.parser:parse" - else: - if is_graphql_2: - txn_name = "graphql.validation.validation:validate" - else: - txn_name = "graphql.validation.validate:validate" - - # Import path differs between versions - if exc_class == "GraphQLError": - from graphql.error import GraphQLError - - exc_class = callable_name(GraphQLError) - - _test_exception_scoped_metrics = [ - ("GraphQL/operation/Strawberry///", 1), - ] - _test_exception_rollup_metrics = [ - ("Errors/all", 1), - ("Errors/allOther", 1), - ("Errors/OtherTransaction/GraphQL/%s" % txn_name, 1), - ] + _test_exception_scoped_metrics - - # Attributes - _expected_exception_operation_attributes = { - "graphql.operation.type": "", - "graphql.operation.name": "", - "graphql.operation.query": query, - } - - @validate_transaction_metrics( - txn_name, - "GraphQL", - scoped_metrics=_test_exception_scoped_metrics, - rollup_metrics=_test_exception_rollup_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_exception_operation_attributes) - @validate_transaction_errors(errors=[exc_class]) - @background_task() - def _test(): - response = graphql_run(app, query) - assert response.errors - - _test() - - -@dt_enabled -def test_operation_metrics_and_attrs(app, graphql_run): - operation_metrics = [("GraphQL/operation/Strawberry/query/MyQuery/library", 1)] - operation_attrs = { - "graphql.operation.type": "query", - "graphql.operation.name": "MyQuery", - } - - @validate_transaction_metrics( - "query/MyQuery/library", - "GraphQL", - scoped_metrics=operation_metrics, - rollup_metrics=operation_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 16: Transaction, Operation, and 7 Resolvers and Resolver functions - # library, library.name, library.book - # library.book.name and library.book.id for each book resolved (in this case 2) - @validate_span_events(count=16) - @validate_span_events(exact_agents=operation_attrs) - @background_task() - def _test(): - response = graphql_run(app, "query MyQuery { library(index: 0) { branch, book { id, name } } }") - assert not response.errors - - _test() - - -@dt_enabled -def test_field_resolver_metrics_and_attrs(app, graphql_run): - field_resolver_metrics = [("GraphQL/resolve/Strawberry/hello", 1)] - graphql_attrs = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String!", - } - - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=field_resolver_metrics, - rollup_metrics=field_resolver_metrics + _graphql_base_rollup_metrics, - background_task=True, - ) - # Span count 4: Transaction, Operation, and 1 Resolver and Resolver function - @validate_span_events(count=4) - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, "{ hello }") - assert not response.errors - assert "Hello!" in str(response.data) - - _test() - - -_test_queries = [ - ("{ hello }", "{ hello }"), # Basic query extraction - ("{ error }", "{ error }"), # Extract query on field error - (to_graphql_source("{ hello }"), "{ hello }"), # Extract query from Source objects - ( - "{ library(index: 0) { branch } }", - "{ library(index: ?) { branch } }", - ), # Integers - ('{ echo(echo: "123") }', "{ echo(echo: ?) }"), # Strings with numerics - ('{ echo(echo: "test") }', "{ echo(echo: ?) }"), # Strings - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Aliases - ('{ TestEcho: echo(echo: "test") }', "{ TestEcho: echo(echo: ?) }"), # Variables - ( # Fragments - '{ ...MyFragment } fragment MyFragment on Query { echo(echo: "test") }', - "{ ...MyFragment } fragment MyFragment on Query { echo(echo: ?) }", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,obfuscated", _test_queries) -def test_query_obfuscation(app, graphql_run, query, obfuscated): - graphql_attrs = {"graphql.operation.query": obfuscated} - - if callable(query): - query = query() - - @validate_span_events(exact_agents=graphql_attrs) - @background_task() - def _test(): - response = graphql_run(app, query) - if not isinstance(query, str) or "error" not in query: - assert not response.errors - - _test() - - -_test_queries = [ - ("{ hello }", "/hello"), # Basic query - ("{ error }", "/error"), # Extract deepest path on field error - ('{ echo(echo: "test") }', "/echo"), # Fields with arguments - ( - "{ library(index: 0) { branch, book { isbn branch } } }", - "/library", - ), # Complex Example, 1 level - ( - "{ library(index: 0) { book { author { first_name }} } }", - "/library.book.author.first_name", - ), # Complex Example, 2 levels - ("{ library(index: 0) { id, book { name } } }", "/library.book.name"), # Filtering - ('{ TestEcho: echo(echo: "test") }', "/echo"), # Aliases - ( - '{ search(contains: "A") { __typename ... on Book { name } } }', - "/search.name", - ), # InlineFragment - ( - '{ hello echo(echo: "test") }', - "", - ), # Multiple root selections. (need to decide on final behavior) - # FragmentSpread - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { name id }", # Fragment filtering - "/library.book.name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } } } fragment MyFragment on Book { author { first_name } }", - "/library.book.author.first_name", - ), - ( - "{ library(index: 0) { book { ...MyFragment } magazine { ...MagFragment } } } fragment MyFragment on Book { author { first_name } } fragment MagFragment on Magazine { name }", - "/library", - ), -] - - -@dt_enabled -@pytest.mark.parametrize("query,expected_path", _test_queries) -def test_deepest_unique_path(app, graphql_run, query, expected_path): - if expected_path == "/error": - txn_name = "_target_application:resolve_error" - else: - txn_name = "query/%s" % expected_path - - @validate_transaction_metrics( - txn_name, - "GraphQL", - background_task=True, - ) - @background_task() - def _test(): - response = graphql_run(app, query) - if "error" not in query: - assert not response.errors - - _test() + assert STRAWBERRY_VERSION is not None + return "Strawberry", STRAWBERRY_VERSION, target_application, not is_asgi, schema_type, 0 @pytest.mark.parametrize("capture_introspection_setting", (True, False)) -def test_introspection_transactions(app, graphql_run, capture_introspection_setting): +def test_introspection_transactions(target_application, capture_introspection_setting): + framework, version, target_application, is_bg, schema_type, extra_spans = target_application + txn_ct = 1 if capture_introspection_setting else 0 @override_application_settings( @@ -450,7 +49,6 @@ def test_introspection_transactions(app, graphql_run, capture_introspection_sett @validate_transaction_count(txn_ct) @background_task() def _test(): - response = graphql_run(app, "{ __schema { types { name } } }") - assert not response.errors + response = target_application("{ __schema { types { name } } }") _test() diff --git a/tests/framework_strawberry/test_application_async.py b/tests/framework_strawberry/test_application_async.py deleted file mode 100644 index 1354c4c01..000000000 --- a/tests/framework_strawberry/test_application_async.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio - -import pytest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - -from newrelic.api.background_task import background_task - - -@pytest.fixture(scope="session") -def graphql_run_async(): - """Wrapper function to simulate framework_graphql test behavior.""" - - def execute(schema, *args, **kwargs): - return schema.execute(*args, **kwargs) - - return execute - - -_graphql_base_rollup_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 1), - ("GraphQL/allOther", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/Strawberry/allOther", 1), -] - - -loop = asyncio.new_event_loop() - - -def test_basic(app, graphql_run_async): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - - @validate_transaction_metrics( - "query//hello_async", - "GraphQL", - rollup_metrics=_graphql_base_rollup_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, "{ hello_async }") - assert not response.errors - - loop.run_until_complete(coro()) - - _test() - - -@dt_enabled -def test_query_and_mutation_async(app, graphql_run_async): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_mutation_unscoped_metrics = [ - ("OtherTransaction/all", 1), - ("GraphQL/all", 2), - ("GraphQL/Strawberry/all", 2), - ("GraphQL/allOther", 2), - ("GraphQL/Strawberry/allOther", 2), - ] + _test_mutation_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - background_task=True, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @background_task() - def _test(): - async def coro(): - response = await graphql_run_async(app, 'mutation { storage_add(string: "abc") }') - assert not response.errors - response = await graphql_run_async(app, "query { storage }") - assert not response.errors - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.data) - assert "abc" in str(response.data) - - loop.run_until_complete(coro()) - - _test() diff --git a/tests/framework_strawberry/test_asgi.py b/tests/framework_strawberry/test_asgi.py deleted file mode 100644 index 8acbaedfb..000000000 --- a/tests/framework_strawberry/test_asgi.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import pytest -from testing_support.asgi_testing import AsgiTest -from testing_support.fixtures import dt_enabled -from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events - - -@pytest.fixture(scope="session") -def graphql_asgi_run(): - """Wrapper function to simulate framework_graphql test behavior.""" - from _target_application import _target_asgi_application - - app = AsgiTest(_target_asgi_application) - - def execute(query): - return app.make_request( - "POST", - "/", - headers={"Content-Type": "application/json"}, - body=json.dumps({"query": query}), - ) - - return execute - - -@dt_enabled -def test_query_and_mutation_asgi(graphql_asgi_run): - from graphql import __version__ as version - - from newrelic.hooks.framework_strawberry import framework_details - - FRAMEWORK_METRICS = [ - ("Python/Framework/Strawberry/%s" % framework_details()[1], 1), - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_mutation_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage_add", 1), - ("GraphQL/operation/Strawberry/mutation//storage_add", 1), - ] - _test_query_scoped_metrics = [ - ("GraphQL/resolve/Strawberry/storage", 1), - ("GraphQL/operation/Strawberry/query//storage", 1), - ] - _test_unscoped_metrics = [ - ("WebTransaction", 1), - ("GraphQL/all", 1), - ("GraphQL/Strawberry/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/Strawberry/allWeb", 1), - ] - _test_mutation_unscoped_metrics = _test_unscoped_metrics + _test_mutation_scoped_metrics - _test_query_unscoped_metrics = _test_unscoped_metrics + _test_query_scoped_metrics - - _expected_mutation_operation_attributes = { - "graphql.operation.type": "mutation", - "graphql.operation.name": "", - } - _expected_mutation_resolver_attributes = { - "graphql.field.name": "storage_add", - "graphql.field.parentType": "Mutation", - "graphql.field.path": "storage_add", - "graphql.field.returnType": "String!", - } - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "storage", - "graphql.field.parentType": "Query", - "graphql.field.path": "storage", - "graphql.field.returnType": "[String!]!", - } - - @validate_transaction_metrics( - "query//storage", - "GraphQL", - scoped_metrics=_test_query_scoped_metrics, - rollup_metrics=_test_query_unscoped_metrics + FRAMEWORK_METRICS, - ) - @validate_transaction_metrics( - "mutation//storage_add", - "GraphQL", - scoped_metrics=_test_mutation_scoped_metrics, - rollup_metrics=_test_mutation_unscoped_metrics + FRAMEWORK_METRICS, - index=-2, - ) - @validate_span_events(exact_agents=_expected_mutation_operation_attributes, index=-2) - @validate_span_events(exact_agents=_expected_mutation_resolver_attributes, index=-2) - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - def _test(): - response = graphql_asgi_run('mutation { storage_add(string: "abc") }') - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - response = graphql_asgi_run("query { storage }") - assert response.status == 200 - response = json.loads(response.body.decode("utf-8")) - assert not response.get("errors") - - # These are separate assertions because pypy stores 'abc' as a unicode string while other Python versions do not - assert "storage" in str(response.get("data")) - assert "abc" in str(response.get("data")) - - _test() diff --git a/tests/logger_structlog/conftest.py b/tests/logger_structlog/conftest.py new file mode 100644 index 000000000..05a86d8a7 --- /dev/null +++ b/tests/logger_structlog/conftest.py @@ -0,0 +1,143 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import pytest +from structlog import DropEvent, PrintLogger +from newrelic.api.time_trace import current_trace +from newrelic.api.transaction import current_transaction +from testing_support.fixtures import ( + collector_agent_registration_fixture, + collector_available_fixture, +) + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "application_logging.enabled": True, + "application_logging.forwarding.enabled": True, + "application_logging.metrics.enabled": True, + "application_logging.local_decorating.enabled": True, + "event_harvest_config.harvest_limits.log_event_data": 100000, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (logger_structlog)", + default_settings=_default_settings, +) + + +class StructLogCapLog(PrintLogger): + def __init__(self, caplog): + self.caplog = caplog if caplog is not None else [] + + def msg(self, event, **kwargs): + self.caplog.append(event) + return + + log = debug = info = warn = warning = msg + fatal = failure = err = error = critical = exception = msg + + def __repr__(self): + return "" % str(id(self)) + + __str__ = __repr__ + + +@pytest.fixture +def set_trace_ids(): + def _set(): + txn = current_transaction() + if txn: + txn._trace_id = "abcdefgh12345678" + trace = current_trace() + if trace: + trace.guid = "abcdefgh" + return _set + +def drop_event_processor(logger, method_name, event_dict): + if method_name == "info": + raise DropEvent + else: + return event_dict + + +@pytest.fixture(scope="function") +def structlog_caplog(): + return list() + + +@pytest.fixture(scope="function") +def logger(structlog_caplog): + import structlog + structlog.configure(processors=[], logger_factory=lambda *args, **kwargs: StructLogCapLog(structlog_caplog)) + _logger = structlog.get_logger() + return _logger + + +@pytest.fixture(scope="function") +def filtering_logger(structlog_caplog): + import structlog + structlog.configure(processors=[drop_event_processor], logger_factory=lambda *args, **kwargs: StructLogCapLog(structlog_caplog)) + _filtering_logger = structlog.get_logger() + return _filtering_logger + + +@pytest.fixture +def exercise_logging_multiple_lines(set_trace_ids, logger, structlog_caplog): + def _exercise(): + set_trace_ids() + + logger.msg("Cat", a=42) + logger.error("Dog") + logger.critical("Elephant") + + assert len(structlog_caplog) == 3 + + assert "Cat" in structlog_caplog[0] + assert "Dog" in structlog_caplog[1] + assert "Elephant" in structlog_caplog[2] + + return _exercise + + +@pytest.fixture +def exercise_filtering_logging_multiple_lines(set_trace_ids, filtering_logger, structlog_caplog): + def _exercise(): + set_trace_ids() + + filtering_logger.msg("Cat", a=42) + filtering_logger.error("Dog") + filtering_logger.critical("Elephant") + + assert len(structlog_caplog) == 2 + + assert "Cat" not in structlog_caplog[0] + assert "Dog" in structlog_caplog[0] + assert "Elephant" in structlog_caplog[1] + + return _exercise + + +@pytest.fixture +def exercise_logging_single_line(set_trace_ids, logger, structlog_caplog): + def _exercise(): + set_trace_ids() + logger.error("A", key="value") + assert len(structlog_caplog) == 1 + + return _exercise diff --git a/tests/logger_structlog/test_attribute_forwarding.py b/tests/logger_structlog/test_attribute_forwarding.py new file mode 100644 index 000000000..eb555cca1 --- /dev/null +++ b/tests/logger_structlog/test_attribute_forwarding.py @@ -0,0 +1,49 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.api.background_task import background_task +from testing_support.fixtures import override_application_settings, reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import validate_log_event_count_outside_transaction +from testing_support.validators.validate_log_events import validate_log_events +from testing_support.validators.validate_log_events_outside_transaction import validate_log_events_outside_transaction + + +_event_attributes = {"message": "A"} + + +@override_application_settings({ + "application_logging.forwarding.context_data.enabled": True, +}) +def test_attributes_inside_transaction(exercise_logging_single_line): + @validate_log_events([_event_attributes]) + @validate_log_event_count(1) + @background_task() + def test(): + exercise_logging_single_line() + + test() + + +@reset_core_stats_engine() +@override_application_settings({ + "application_logging.forwarding.context_data.enabled": True, +}) +def test_attributes_outside_transaction(exercise_logging_single_line): + @validate_log_events_outside_transaction([_event_attributes]) + @validate_log_event_count_outside_transaction(1) + def test(): + exercise_logging_single_line() + + test() diff --git a/tests/logger_structlog/test_local_decorating.py b/tests/logger_structlog/test_local_decorating.py new file mode 100644 index 000000000..7b58d4a0c --- /dev/null +++ b/tests/logger_structlog/test_local_decorating.py @@ -0,0 +1,54 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + +from newrelic.api.application import application_settings +from newrelic.api.background_task import background_task +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import validate_log_event_count_outside_transaction + + +def get_metadata_string(log_message, is_txn): + host = platform.uname()[1] + assert host + entity_guid = application_settings().entity_guid + if is_txn: + metadata_string = "".join(('NR-LINKING|', entity_guid, '|', host, '|abcdefgh12345678|abcdefgh|Python%20Agent%20Test%20%28logger_structlog%29|')) + else: + metadata_string = "".join(('NR-LINKING|', entity_guid, '|', host, '|||Python%20Agent%20Test%20%28logger_structlog%29|')) + formatted_string = log_message + " " + metadata_string + return formatted_string + + +@reset_core_stats_engine() +def test_local_log_decoration_inside_transaction(exercise_logging_single_line, structlog_caplog): + @validate_log_event_count(1) + @background_task() + def test(): + exercise_logging_single_line() + assert get_metadata_string('A', True) in structlog_caplog[0] + + test() + + +@reset_core_stats_engine() +def test_local_log_decoration_outside_transaction(exercise_logging_single_line, structlog_caplog): + @validate_log_event_count_outside_transaction(1) + def test(): + exercise_logging_single_line() + assert get_metadata_string('A', False) in structlog_caplog[0] + + test() diff --git a/tests/logger_structlog/test_log_forwarding.py b/tests/logger_structlog/test_log_forwarding.py new file mode 100644 index 000000000..e5a5e670f --- /dev/null +++ b/tests/logger_structlog/test_log_forwarding.py @@ -0,0 +1,88 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.api.background_task import background_task +from testing_support.fixtures import override_application_settings, reset_core_stats_engine +from testing_support.validators.validate_log_event_count import validate_log_event_count +from testing_support.validators.validate_log_event_count_outside_transaction import \ + validate_log_event_count_outside_transaction +from testing_support.validators.validate_log_events import validate_log_events +from testing_support.validators.validate_log_events_outside_transaction import validate_log_events_outside_transaction + + +_common_attributes_service_linking = {"timestamp": None, "hostname": None, + "entity.name": "Python Agent Test (logger_structlog)", "entity.guid": None} + +_common_attributes_trace_linking = {"span.id": "abcdefgh", "trace.id": "abcdefgh12345678", + **_common_attributes_service_linking} + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_inside_transaction(exercise_logging_multiple_lines): + @validate_log_events([ + {"message": "Cat", "level": "INFO", **_common_attributes_trace_linking}, + {"message": "Dog", "level": "ERROR", **_common_attributes_trace_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_trace_linking}, + ]) + @validate_log_event_count(3) + @background_task() + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_filtering_inside_transaction(exercise_filtering_logging_multiple_lines): + @validate_log_events([ + {"message": "Dog", "level": "ERROR", **_common_attributes_trace_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_trace_linking}, + ]) + @validate_log_event_count(2) + @background_task() + def test(): + exercise_filtering_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_outside_transaction(exercise_logging_multiple_lines): + @validate_log_events_outside_transaction([ + {"message": "Cat", "level": "INFO", **_common_attributes_service_linking}, + {"message": "Dog", "level": "ERROR", **_common_attributes_service_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_service_linking}, + ]) + @validate_log_event_count_outside_transaction(3) + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +@override_application_settings({"application_logging.local_decorating.enabled": False}) +def test_logging_filtering_outside_transaction(exercise_filtering_logging_multiple_lines): + @validate_log_events_outside_transaction([ + {"message": "Dog", "level": "ERROR", **_common_attributes_service_linking}, + {"message": "Elephant", "level": "CRITICAL", **_common_attributes_service_linking}, + ]) + @validate_log_event_count_outside_transaction(2) + def test(): + exercise_filtering_logging_multiple_lines() + + test() diff --git a/tests/logger_structlog/test_metrics.py b/tests/logger_structlog/test_metrics.py new file mode 100644 index 000000000..48f7204e8 --- /dev/null +++ b/tests/logger_structlog/test_metrics.py @@ -0,0 +1,73 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from newrelic.packages import six +from newrelic.api.background_task import background_task +from testing_support.fixtures import reset_core_stats_engine +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from testing_support.validators.validate_custom_metrics_outside_transaction import validate_custom_metrics_outside_transaction + + +_test_logging_unscoped_metrics = [ + ("Logging/lines", 3), + ("Logging/lines/INFO", 1), + ("Logging/lines/ERROR", 1), + ("Logging/lines/CRITICAL", 1), +] + + +@reset_core_stats_engine() +def test_logging_metrics_inside_transaction(exercise_logging_multiple_lines): + txn_name = "test_metrics:test_logging_metrics_inside_transaction..test" if six.PY3 else "test_metrics:test" + @validate_transaction_metrics( + txn_name, + custom_metrics=_test_logging_unscoped_metrics, + background_task=True, + ) + @background_task() + def test(): + exercise_logging_multiple_lines() + + test() + + +@reset_core_stats_engine() +def test_logging_metrics_outside_transaction(exercise_logging_multiple_lines): + @validate_custom_metrics_outside_transaction(_test_logging_unscoped_metrics) + def test(): + exercise_logging_multiple_lines() + + test() + + +_test_logging_unscoped_filtering_metrics = [ + ("Logging/lines", 2), + ("Logging/lines/ERROR", 1), + ("Logging/lines/CRITICAL", 1), +] + + +@reset_core_stats_engine() +def test_filtering_logging_metrics_inside_transaction(exercise_filtering_logging_multiple_lines): + txn_name = "test_metrics:test_filtering_logging_metrics_inside_transaction..test" if six.PY3 else "test_metrics:test" + @validate_transaction_metrics( + txn_name, + custom_metrics=_test_logging_unscoped_filtering_metrics, + background_task=True, + ) + @background_task() + def test(): + exercise_filtering_logging_multiple_lines() + + test() diff --git a/tests/messagebroker_pika/test_pika_async_connection_consume.py b/tests/messagebroker_pika/test_pika_async_connection_consume.py index 4e44c7ed7..29b9d8ea4 100644 --- a/tests/messagebroker_pika/test_pika_async_connection_consume.py +++ b/tests/messagebroker_pika/test_pika_async_connection_consume.py @@ -49,20 +49,20 @@ from newrelic.api.background_task import background_task + DB_SETTINGS = rabbitmq_settings()[0] _message_broker_tt_params = { - "queue_name": QUEUE, - "routing_key": QUEUE, - "correlation_id": CORRELATION_ID, - "reply_to": REPLY_TO, - "headers": HEADERS.copy(), + 'queue_name': QUEUE, + 'routing_key': QUEUE, + 'correlation_id': CORRELATION_ID, + 'reply_to': REPLY_TO, + 'headers': HEADERS.copy(), } # Tornado's IO loop is not configurable in versions 5.x and up try: - class MyIOLoop(tornado.ioloop.IOLoop.configured_class()): def handle_callback_exception(self, *args, **kwargs): raise @@ -73,44 +73,38 @@ def handle_callback_exception(self, *args, **kwargs): connection_classes = [pika.SelectConnection, TornadoConnection] -parametrized_connection = pytest.mark.parametrize("ConnectionClass", connection_classes) +parametrized_connection = pytest.mark.parametrize('ConnectionClass', + connection_classes) _test_select_conn_basic_get_inside_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1), ] if six.PY3: _test_select_conn_basic_get_inside_txn_metrics.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_get_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn.' + '.on_message'), 1)) else: - _test_select_conn_basic_get_inside_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1)) + _test_select_conn_basic_get_inside_txn_metrics.append( + ('Function/test_pika_async_connection_consume:on_message', 1)) @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_async_connection_consume" + (".test_async_connection_basic_get_inside_txn." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('callback_as_partial', [True, False]) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_get_inside_txn.", "on_message", py2_namespace="test_pika_async_connection_consume") @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn"), - scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics, - rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics, - background_task=True, -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn'), + scoped_metrics=_test_select_conn_basic_get_inside_txn_metrics, + rollup_metrics=_test_select_conn_basic_get_inside_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_inside_txn(producer, ConnectionClass, callback_as_partial): +def test_async_connection_basic_get_inside_txn(producer, ConnectionClass, + callback_as_partial): def on_message(channel, method_frame, header_frame, body): assert method_frame assert body == BODY @@ -128,7 +122,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -139,8 +135,9 @@ def on_open_connection(connection): @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) -def test_select_connection_basic_get_outside_txn(producer, ConnectionClass, callback_as_partial): +@pytest.mark.parametrize('callback_as_partial', [True, False]) +def test_select_connection_basic_get_outside_txn(producer, ConnectionClass, + callback_as_partial): metrics_list = [] @capture_transaction_metrics(metrics_list) @@ -163,8 +160,8 @@ def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) connection = ConnectionClass( - pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection - ) + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -181,24 +178,25 @@ def on_open_connection(connection): _test_select_conn_basic_get_inside_txn_no_callback_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @pytest.mark.skipif( - condition=pika_version_info[0] > 0, reason="pika 1.0 removed the ability to use basic_get with callback=None" -) + condition=pika_version_info[0] > 0, + reason='pika 1.0 removed the ability to use basic_get with callback=None') @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_inside_txn_no_callback"), + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_inside_txn_no_callback'), scoped_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics, rollup_metrics=_test_select_conn_basic_get_inside_txn_no_callback_metrics, - background_task=True, -) + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_inside_txn_no_callback(producer, ConnectionClass): +def test_async_connection_basic_get_inside_txn_no_callback(producer, + ConnectionClass): def on_open_channel(channel): channel.basic_get(callback=None, queue=QUEUE) channel.close() @@ -208,7 +206,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -219,26 +219,27 @@ def on_open_connection(connection): _test_async_connection_basic_get_empty_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @parametrized_connection -@pytest.mark.parametrize("callback_as_partial", [True, False]) +@pytest.mark.parametrize('callback_as_partial', [True, False]) @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_get_empty"), - scoped_metrics=_test_async_connection_basic_get_empty_metrics, - rollup_metrics=_test_async_connection_basic_get_empty_metrics, - background_task=True, -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_get_empty'), + scoped_metrics=_test_async_connection_basic_get_empty_metrics, + rollup_metrics=_test_async_connection_basic_get_empty_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() -def test_async_connection_basic_get_empty(ConnectionClass, callback_as_partial): - QUEUE = "test_async_empty" +def test_async_connection_basic_get_empty(ConnectionClass, + callback_as_partial): + QUEUE = 'test_async_empty' def on_message(channel, method_frame, header_frame, body): - assert False, body.decode("UTF-8") + assert False, body.decode('UTF-8') if callback_as_partial: on_message = functools.partial(on_message) @@ -252,7 +253,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -263,42 +266,33 @@ def on_open_connection(connection): _test_select_conn_basic_consume_in_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_select_conn_basic_consume_in_txn_metrics.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_inside_txn.' + '.on_message'), 1)) else: - _test_select_conn_basic_consume_in_txn_metrics.append(("Function/test_pika_async_connection_consume:on_message", 1)) + _test_select_conn_basic_consume_in_txn_metrics.append( + ('Function/test_pika_async_connection_consume:on_message', 1)) @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_inside_txn"), - scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics, - rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics, - background_task=True, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_inside_txn." if six.PY3 else ""), - "on_message", -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_inside_txn'), + scoped_metrics=_test_select_conn_basic_consume_in_txn_metrics, + rollup_metrics=_test_select_conn_basic_consume_in_txn_metrics, + background_task=True) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_inside_txn.", "on_message", py2_namespace="test_pika_async_connection_consume") @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_async_connection_basic_consume_inside_txn(producer, ConnectionClass): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -311,7 +305,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -322,67 +318,46 @@ def on_open_connection(connection): _test_select_conn_basic_consume_two_exchanges = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE_2, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE_2, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE_2, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE_2, None), ] if six.PY3: _test_select_conn_basic_consume_two_exchanges.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_two_exchanges." - ".on_message_1" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges.' + '.on_message_1'), 1)) _test_select_conn_basic_consume_two_exchanges.append( - ( - ( - "Function/test_pika_async_connection_consume:" - "test_async_connection_basic_consume_two_exchanges." - ".on_message_2" - ), - 1, - ) - ) + (('Function/test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges.' + '.on_message_2'), 1)) else: _test_select_conn_basic_consume_two_exchanges.append( - ("Function/test_pika_async_connection_consume:on_message_1", 1) - ) + ('Function/test_pika_async_connection_consume:on_message_1', 1)) _test_select_conn_basic_consume_two_exchanges.append( - ("Function/test_pika_async_connection_consume:on_message_2", 1) - ) + ('Function/test_pika_async_connection_consume:on_message_2', 1)) @parametrized_connection @validate_transaction_metrics( - ("test_pika_async_connection_consume:" "test_async_connection_basic_consume_two_exchanges"), - scoped_metrics=_test_select_conn_basic_consume_two_exchanges, - rollup_metrics=_test_select_conn_basic_consume_two_exchanges, - background_task=True, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_two_exchanges." if six.PY3 else ""), - "on_message_1", -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_async_connection_basic_consume_two_exchanges." if six.PY3 else ""), - "on_message_2", -) + ('test_pika_async_connection_consume:' + 'test_async_connection_basic_consume_two_exchanges'), + scoped_metrics=_test_select_conn_basic_consume_two_exchanges, + rollup_metrics=_test_select_conn_basic_consume_two_exchanges, + background_task=True) +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.", "on_message_1", py2_namespace="test_pika_async_connection_consume") +@validate_code_level_metrics("test_pika_async_connection_consume.test_async_connection_basic_consume_two_exchanges.", "on_message_2", py2_namespace="test_pika_async_connection_consume") @background_task() -def test_async_connection_basic_consume_two_exchanges(producer, producer_2, ConnectionClass): +def test_async_connection_basic_consume_two_exchanges(producer, producer_2, + ConnectionClass): global events_received events_received = 0 def on_message_1(channel, method_frame, header_frame, body): channel.basic_ack(method_frame.delivery_tag) - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY global events_received @@ -395,7 +370,7 @@ def on_message_1(channel, method_frame, header_frame, body): def on_message_2(channel, method_frame, header_frame, body): channel.basic_ack(method_frame.delivery_tag) - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY global events_received @@ -413,7 +388,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = ConnectionClass(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = ConnectionClass( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -424,11 +401,12 @@ def on_open_connection(connection): # This should not create a transaction -@function_not_called("newrelic.core.stats_engine", "StatsEngine.record_transaction") -@override_application_settings({"debug.record_transaction_failure": True}) +@function_not_called('newrelic.core.stats_engine', + 'StatsEngine.record_transaction') +@override_application_settings({'debug.record_transaction_failure': True}) def test_tornado_connection_basic_consume_outside_transaction(producer): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -441,7 +419,9 @@ def on_open_channel(channel): def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) - connection = TornadoConnection(pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection) + connection = TornadoConnection( + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() @@ -452,44 +432,31 @@ def on_open_connection(connection): if six.PY3: - _txn_name = ( - "test_pika_async_connection_consume:" - "test_select_connection_basic_consume_outside_transaction." - ".on_message" - ) + _txn_name = ('test_pika_async_connection_consume:' + 'test_select_connection_basic_consume_outside_transaction.' + '.on_message') _test_select_connection_consume_outside_txn_metrics = [ - ( - ( - "Function/test_pika_async_connection_consume:" - "test_select_connection_basic_consume_outside_transaction." - ".on_message" - ), - None, - ) - ] + (('Function/test_pika_async_connection_consume:' + 'test_select_connection_basic_consume_outside_transaction.' + '.on_message'), None)] else: - _txn_name = "test_pika_async_connection_consume:on_message" + _txn_name = ( + 'test_pika_async_connection_consume:on_message') _test_select_connection_consume_outside_txn_metrics = [ - ("Function/test_pika_async_connection_consume:on_message", None) - ] + ('Function/test_pika_async_connection_consume:on_message', None)] # This should create a transaction @validate_transaction_metrics( - _txn_name, - scoped_metrics=_test_select_connection_consume_outside_txn_metrics, - rollup_metrics=_test_select_connection_consume_outside_txn_metrics, - background_task=True, - group="Message/RabbitMQ/Exchange/%s" % EXCHANGE, -) -@validate_code_level_metrics( - "test_pika_async_connection_consume" - + (".test_select_connection_basic_consume_outside_transaction." if six.PY3 else ""), - "on_message", -) + _txn_name, + scoped_metrics=_test_select_connection_consume_outside_txn_metrics, + rollup_metrics=_test_select_connection_consume_outside_txn_metrics, + background_task=True, + group='Message/RabbitMQ/Exchange/%s' % EXCHANGE) +@validate_code_level_metrics("test_pika_async_connection_consume.test_select_connection_basic_consume_outside_transaction.", "on_message", py2_namespace="test_pika_async_connection_consume") def test_select_connection_basic_consume_outside_transaction(producer): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.basic_ack(method_frame.delivery_tag) channel.close() @@ -503,8 +470,8 @@ def on_open_connection(connection): connection.channel(on_open_callback=on_open_channel) connection = pika.SelectConnection( - pika.ConnectionParameters(DB_SETTINGS["host"]), on_open_callback=on_open_connection - ) + pika.ConnectionParameters(DB_SETTINGS['host']), + on_open_callback=on_open_connection) try: connection.ioloop.start() diff --git a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py index 7b41674a2..e097cfbe9 100644 --- a/tests/messagebroker_pika/test_pika_blocking_connection_consume.py +++ b/tests/messagebroker_pika/test_pika_blocking_connection_consume.py @@ -38,30 +38,32 @@ DB_SETTINGS = rabbitmq_settings()[0] _message_broker_tt_params = { - "queue_name": QUEUE, - "routing_key": QUEUE, - "correlation_id": CORRELATION_ID, - "reply_to": REPLY_TO, - "headers": HEADERS.copy(), + 'queue_name': QUEUE, + 'routing_key': QUEUE, + 'correlation_id': CORRELATION_ID, + 'reply_to': REPLY_TO, + 'headers': HEADERS.copy(), } _test_blocking_connection_basic_get_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, 1), - (("Function/pika.adapters.blocking_connection:" "_CallbackResult.set_value_once"), 1), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, 1), + (('Function/pika.adapters.blocking_connection:' + '_CallbackResult.set_value_once'), 1) ] @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get"), - scoped_metrics=_test_blocking_connection_basic_get_metrics, - rollup_metrics=_test_blocking_connection_basic_get_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_get'), + scoped_metrics=_test_blocking_connection_basic_get_metrics, + rollup_metrics=_test_blocking_connection_basic_get_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_get(producer): - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() method_frame, _, _ = channel.basic_get(QUEUE) assert method_frame @@ -69,22 +71,23 @@ def test_blocking_connection_basic_get(producer): _test_blocking_connection_basic_get_empty_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_get_empty"), - scoped_metrics=_test_blocking_connection_basic_get_empty_metrics, - rollup_metrics=_test_blocking_connection_basic_get_empty_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_get_empty'), + scoped_metrics=_test_blocking_connection_basic_get_empty_metrics, + rollup_metrics=_test_blocking_connection_basic_get_empty_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_get_empty(): - QUEUE = "test_blocking_empty-%s" % os.getpid() - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + QUEUE = 'test_blocking_empty-%s' % os.getpid() + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() channel.queue_declare(queue=QUEUE) @@ -100,7 +103,8 @@ def test_blocking_connection_basic_get_outside_transaction(producer): @capture_transaction_metrics(metrics_list) def test_basic_get(): - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() channel.queue_declare(queue=QUEUE) @@ -116,57 +120,46 @@ def test_basic_get(): _test_blocking_conn_basic_consume_no_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: - _txn_name = ( - "test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_outside_transaction." - ".on_message" - ) + _txn_name = ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_outside_transaction.' + '.on_message') _test_blocking_conn_basic_consume_no_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_outside_transaction." - ".on_message" - ), - None, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_outside_transaction.' + '.on_message'), None)) else: - _txn_name = "test_pika_blocking_connection_consume:" "on_message" + _txn_name = ('test_pika_blocking_connection_consume:' + 'on_message') _test_blocking_conn_basic_consume_no_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", None) - ) + ('Function/test_pika_blocking_connection_consume:on_message', None)) -@pytest.mark.parametrize("as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_blocking_connection_consume" - + (".test_blocking_connection_basic_consume_outside_transaction." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('as_partial', [True, False]) +@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_outside_transaction.", "on_message", py2_namespace="test_pika_blocking_connection_consume") @validate_transaction_metrics( - _txn_name, - scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, - background_task=True, - group="Message/RabbitMQ/Exchange/%s" % EXCHANGE, -) + _txn_name, + scoped_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_no_txn_metrics, + background_task=True, + group='Message/RabbitMQ/Exchange/%s' % EXCHANGE) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) -def test_blocking_connection_basic_consume_outside_transaction(producer, as_partial): +def test_blocking_connection_basic_consume_outside_transaction(producer, + as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) @@ -178,51 +171,41 @@ def on_message(channel, method_frame, header_frame, body): _test_blocking_conn_basic_consume_in_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_blocking_conn_basic_consume_in_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_inside_txn." - ".on_message" - ), - 1, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_inside_txn.' + '.on_message'), 1)) else: _test_blocking_conn_basic_consume_in_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", 1) - ) + ('Function/test_pika_blocking_connection_consume:on_message', 1)) -@pytest.mark.parametrize("as_partial", [True, False]) -@validate_code_level_metrics( - "test_pika_blocking_connection_consume" - + (".test_blocking_connection_basic_consume_inside_txn." if six.PY3 else ""), - "on_message", -) +@pytest.mark.parametrize('as_partial', [True, False]) +@validate_code_level_metrics("test_pika_blocking_connection_consume.test_blocking_connection_basic_consume_inside_txn.", "on_message", py2_namespace="test_pika_blocking_connection_consume") @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_inside_txn"), - scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_inside_txn'), + scoped_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_in_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_consume_inside_txn(producer, as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) try: @@ -233,40 +216,33 @@ def on_message(channel, method_frame, header_frame, body): _test_blocking_conn_basic_consume_stopped_txn_metrics = [ - ("MessageBroker/RabbitMQ/Exchange/Produce/Named/%s" % EXCHANGE, None), - ("MessageBroker/RabbitMQ/Exchange/Consume/Named/%s" % EXCHANGE, None), - ("OtherTransaction/Message/RabbitMQ/Exchange/Named/%s" % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Produce/Named/%s' % EXCHANGE, None), + ('MessageBroker/RabbitMQ/Exchange/Consume/Named/%s' % EXCHANGE, None), + ('OtherTransaction/Message/RabbitMQ/Exchange/Named/%s' % EXCHANGE, None), ] if six.PY3: _test_blocking_conn_basic_consume_stopped_txn_metrics.append( - ( - ( - "Function/test_pika_blocking_connection_consume:" - "test_blocking_connection_basic_consume_stopped_txn." - ".on_message" - ), - None, - ) - ) + (('Function/test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_stopped_txn.' + '.on_message'), None)) else: _test_blocking_conn_basic_consume_stopped_txn_metrics.append( - ("Function/test_pika_blocking_connection_consume:on_message", None) - ) + ('Function/test_pika_blocking_connection_consume:on_message', None)) -@pytest.mark.parametrize("as_partial", [True, False]) +@pytest.mark.parametrize('as_partial', [True, False]) @validate_transaction_metrics( - ("test_pika_blocking_connection_consume:" "test_blocking_connection_basic_consume_stopped_txn"), - scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, - rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, - background_task=True, -) + ('test_pika_blocking_connection_consume:' + 'test_blocking_connection_basic_consume_stopped_txn'), + scoped_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, + rollup_metrics=_test_blocking_conn_basic_consume_stopped_txn_metrics, + background_task=True) @validate_tt_collector_json(message_broker_params=_message_broker_tt_params) @background_task() def test_blocking_connection_basic_consume_stopped_txn(producer, as_partial): def on_message(channel, method_frame, header_frame, body): - assert hasattr(method_frame, "_nr_start_time") + assert hasattr(method_frame, '_nr_start_time') assert body == BODY channel.stop_consuming() @@ -275,7 +251,8 @@ def on_message(channel, method_frame, header_frame, body): if as_partial: on_message = functools.partial(on_message) - with pika.BlockingConnection(pika.ConnectionParameters(DB_SETTINGS["host"])) as connection: + with pika.BlockingConnection( + pika.ConnectionParameters(DB_SETTINGS['host'])) as connection: channel = connection.channel() basic_consume(channel, QUEUE, on_message) try: diff --git a/tests/testing_support/db_settings.py b/tests/testing_support/db_settings.py index b095c0912..f7bda3d7a 100644 --- a/tests/testing_support/db_settings.py +++ b/tests/testing_support/db_settings.py @@ -121,6 +121,31 @@ def redis_settings(): return settings +def redis_cluster_settings(): + """Return a list of dict of settings for connecting to redis cluster. + + Will return the correct settings, depending on which of the environments it + is running in. It attempts to set variables in the following order, where + later environments override earlier ones. + + 1. Local + 2. Github Actions + """ + + host = "host.docker.internal" if "GITHUB_ACTIONS" in os.environ else "localhost" + instances = 1 + base_port = 6379 + + settings = [ + { + "host": host, + "port": base_port + instance_num, + } + for instance_num in range(instances) + ] + return settings + + def memcached_settings(): """Return a list of dict of settings for connecting to memcached. @@ -165,6 +190,28 @@ def mongodb_settings(): return settings +def firestore_settings(): + """Return a list of dict of settings for connecting to firestore. + + This only includes the host and port as the collection name is defined in + the firestore conftest file. + Will return the correct settings, depending on which of the environments it + is running in. It attempts to set variables in the following order, where + later environments override earlier ones. + + 1. Local + 2. Github Actions + """ + + host = "host.docker.internal" if "GITHUB_ACTIONS" in os.environ else "127.0.0.1" + instances = 2 + settings = [ + {"host": host, "port": 8080 + instance_num} + for instance_num in range(instances) + ] + return settings + + def elasticsearch_settings(): """Return a list of dict of settings for connecting to elasticsearch. diff --git a/tests/testing_support/validators/validate_code_level_metrics.py b/tests/testing_support/validators/validate_code_level_metrics.py index d5c4b5648..c3a880b35 100644 --- a/tests/testing_support/validators/validate_code_level_metrics.py +++ b/tests/testing_support/validators/validate_code_level_metrics.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from newrelic.packages import six from testing_support.validators.validate_span_events import validate_span_events from testing_support.fixtures import dt_enabled from newrelic.common.object_wrapper import function_wrapper -def validate_code_level_metrics(namespace, function, builtin=False, count=1, index=-1): + +def validate_code_level_metrics(namespace, function, py2_namespace=None, builtin=False, count=1, index=-1): """Verify that code level metrics are generated for a callable.""" + if six.PY2 and py2_namespace is not None: + namespace = py2_namespace + if builtin: validator = validate_span_events( exact_agents={"code.function": function, "code.namespace": namespace, "code.filepath": ""}, @@ -38,5 +43,4 @@ def validate_code_level_metrics(namespace, function, builtin=False, count=1, ind def wrapper(wrapped, instance, args, kwargs): validator(dt_enabled(wrapped))(*args, **kwargs) - return wrapper - + return wrapper \ No newline at end of file diff --git a/tests/testing_support/validators/validate_datastore_trace_inputs.py b/tests/testing_support/validators/validate_datastore_trace_inputs.py index ade4ebea6..365a14ebd 100644 --- a/tests/testing_support/validators/validate_datastore_trace_inputs.py +++ b/tests/testing_support/validators/validate_datastore_trace_inputs.py @@ -23,7 +23,7 @@ """ -def validate_datastore_trace_inputs(operation=None, target=None): +def validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None): @transient_function_wrapper("newrelic.api.datastore_trace", "DatastoreTrace.__init__") @catch_background_exceptions def _validate_datastore_trace_inputs(wrapped, instance, args, kwargs): @@ -44,6 +44,18 @@ def _bind_params(product, target, operation, host=None, port_path_or_id=None, da assert captured_target == target, "%s didn't match expected %s" % (captured_target, target) if operation is not None: assert captured_operation == operation, "%s didn't match expected %s" % (captured_operation, operation) + if host is not None: + assert captured_host == host, "%s didn't match expected %s" % (captured_host, host) + if port_path_or_id is not None: + assert captured_port_path_or_id == port_path_or_id, "%s didn't match expected %s" % ( + captured_port_path_or_id, + port_path_or_id, + ) + if database_name is not None: + assert captured_database_name == database_name, "%s didn't match expected %s" % ( + captured_database_name, + database_name, + ) return wrapped(*args, **kwargs) diff --git a/tox.ini b/tox.ini index bd9ad936e..f1fcf833c 100644 --- a/tox.ini +++ b/tox.ini @@ -85,6 +85,7 @@ envlist = memcached-datastore_memcache-{py27,py37,py38,py39,py310,py311,pypy27,pypy38}-memcached01, mysql-datastore_mysql-mysql080023-py27, mysql-datastore_mysql-mysqllatest-{py37,py38,py39,py310,py311}, + firestore-datastore_firestore-{py37,py38,py39,py310,py311}, postgres-datastore_postgresql-{py37,py38,py39}, postgres-datastore_psycopg2-{py27,py37,py38,py39,py310,py311}-psycopg2latest postgres-datastore_psycopg2cffi-{py27,pypy27,py37,py38,py39,py310,py311}-psycopg2cffilatest, @@ -129,8 +130,6 @@ envlist = # Falcon master branch failing on 3.11 currently. python-framework_falcon-py311-falcon0200, python-framework_fastapi-{py37,py38,py39,py310,py311}, - python-framework_flask-{pypy27,py27}-flask0012, - python-framework_flask-{pypy27,py27,py37,py38,py39,py310,py311,pypy38}-flask0101, ; temporarily disabling flaskmaster tests python-framework_flask-{py37,py38,py39,py310,py311,pypy38}-flasklatest, python-framework_graphene-{py37,py38,py39,py310,py311}-graphenelatest, @@ -151,6 +150,7 @@ envlist = python-logger_logging-{py27,py37,py38,py39,py310,py311,pypy27,pypy38}, python-logger_loguru-{py37,py38,py39,py310,py311,pypy38}-logurulatest, python-logger_loguru-py39-loguru{06,05,04,03}, + python-logger_structlog-{py37,py38,py39,py310,py311,pypy38}-structloglatest, python-framework_tornado-{py38,py39,py310,py311}-tornadolatest, python-framework_tornado-{py38,py39,py310,py311}-tornadomaster, rabbitmq-messagebroker_pika-{py27,py37,py38,py39,pypy27,pypy38}-pika0.13, @@ -242,6 +242,7 @@ deps = datastore_elasticsearch: requests datastore_elasticsearch-elasticsearch07: elasticsearch<8.0 datastore_elasticsearch-elasticsearch08: elasticsearch<9.0 + datastore_firestore: google-cloud-firestore datastore_memcache-memcached01: python-memcached<2 datastore_mysql-mysqllatest: mysql-connector-python datastore_mysql-mysql080023: mysql-connector-python<8.0.24 @@ -309,8 +310,6 @@ deps = framework_flask: markupsafe<2.1 framework_flask: jinja2<3.1 framework_flask: Flask-Compress - framework_flask-flask0012: flask<0.13 - framework_flask-flask0101: flask<1.2 framework_flask-flasklatest: flask[async] framework_flask-flaskmaster: https://github.com/pallets/werkzeug/archive/main.zip framework_flask-flaskmaster: https://github.com/pallets/flask/archive/main.zip#egg=flask[async] @@ -347,7 +346,6 @@ deps = framework_sanic-saniclatest: sanic framework_sanic-sanic{1812,190301,1906}: aiohttp framework_sanic-sanic{1812,190301,1906,1912,200904,210300,2109,2112,2203,2290}: websockets<11 - framework_starlette: graphene<3 ; For test_exception_in_middleware test, anyio is used: ; https://github.com/encode/starlette/pull/1157 ; but anyiolatest creates breaking changes to our tests @@ -370,6 +368,7 @@ deps = logger_loguru-loguru05: loguru<0.6 logger_loguru-loguru04: loguru<0.5 logger_loguru-loguru03: loguru<0.4 + logger_structlog-structloglatest: structlog messagebroker_pika-pika0.13: pika<0.14 messagebroker_pika-pikalatest: pika messagebroker_pika: tornado<5 @@ -446,6 +445,7 @@ changedir = datastore_asyncpg: tests/datastore_asyncpg datastore_bmemcached: tests/datastore_bmemcached datastore_elasticsearch: tests/datastore_elasticsearch + datastore_firestore: tests/datastore_firestore datastore_memcache: tests/datastore_memcache datastore_mysql: tests/datastore_mysql datastore_postgresql: tests/datastore_postgresql @@ -488,6 +488,7 @@ changedir = framework_tornado: tests/framework_tornado logger_logging: tests/logger_logging logger_loguru: tests/logger_loguru + logger_structlog: tests/logger_structlog messagebroker_pika: tests/messagebroker_pika messagebroker_confluentkafka: tests/messagebroker_confluentkafka messagebroker_kafkapython: tests/messagebroker_kafkapython