Skip to content

Commit

Permalink
feat(ai-monitoring): Cohere integration (#3055)
Browse files Browse the repository at this point in the history
* Cohere integration

* Fix lint

* Fix bug with model ID not being pulled

* Exclude known models from langchain

* tox.ini

* Removed print statement

* Apply suggestions from code review

Co-authored-by: Anton Pirker <anton.pirker@sentry.io>

---------

Co-authored-by: Anton Pirker <anton.pirker@sentry.io>
  • Loading branch information
colin-sentry and antonpirker committed May 10, 2024
1 parent 1a32183 commit 40746ef
Show file tree
Hide file tree
Showing 10 changed files with 523 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/test-integrations-data-processing.yml
Expand Up @@ -58,6 +58,10 @@ jobs:
run: |
set -x # print commands that are executed
./scripts/runtox.sh "py${{ matrix.python-version }}-celery-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
- name: Test cohere latest
run: |
set -x # print commands that are executed
./scripts/runtox.sh "py${{ matrix.python-version }}-cohere-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
- name: Test huey latest
run: |
set -x # print commands that are executed
Expand Down Expand Up @@ -126,6 +130,10 @@ jobs:
run: |
set -x # print commands that are executed
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-celery" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
- name: Test cohere pinned
run: |
set -x # print commands that are executed
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-cohere" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
- name: Test huey pinned
run: |
set -x # print commands that are executed
Expand Down
3 changes: 2 additions & 1 deletion mypy.ini
Expand Up @@ -25,7 +25,8 @@ warn_unused_ignores = True
;
; Do not use wildcards in module paths, otherwise added modules will
; automatically have the same set of relaxed rules as the rest

[mypy-cohere.*]
ignore_missing_imports = True
[mypy-django.*]
ignore_missing_imports = True
[mypy-pyramid.*]
Expand Down
1 change: 1 addition & 0 deletions scripts/split-tox-gh-actions/split-tox-gh-actions.py
Expand Up @@ -70,6 +70,7 @@
"arq",
"beam",
"celery",
"cohere",
"huey",
"langchain",
"openai",
Expand Down
33 changes: 33 additions & 0 deletions sentry_sdk/consts.py
Expand Up @@ -91,6 +91,18 @@ class SPANDATA:
See: https://develop.sentry.dev/sdk/performance/span-data-conventions/
"""

AI_FREQUENCY_PENALTY = "ai.frequency_penalty"
"""
Used to reduce repetitiveness of generated tokens.
Example: 0.5
"""

AI_PRESENCE_PENALTY = "ai.presence_penalty"
"""
Used to reduce repetitiveness of generated tokens.
Example: 0.5
"""

AI_INPUT_MESSAGES = "ai.input_messages"
"""
The input messages to an LLM call.
Expand Down Expand Up @@ -164,12 +176,31 @@ class SPANDATA:
For an AI model call, the logit bias
"""

AI_PREAMBLE = "ai.preamble"
"""
For an AI model call, the preamble parameter.
Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style.
Example: "You are now a clown."
"""

AI_RAW_PROMPTING = "ai.raw_prompting"
"""
Minimize pre-processing done to the prompt sent to the LLM.
Example: true
"""

AI_RESPONSES = "ai.responses"
"""
The responses to an AI model call. Always as a list.
Example: ["hello", "world"]
"""

AI_SEED = "ai.seed"
"""
The seed, ideally models given the same seed and same other parameters will produce the exact same output.
Example: 123.45
"""

DB_NAME = "db.name"
"""
The name of the database being accessed. For commands that switch the database, this should be set to the target database (even if the command fails).
Expand Down Expand Up @@ -298,6 +329,8 @@ class SPANDATA:
class OP:
ANTHROPIC_MESSAGES_CREATE = "ai.messages.create.anthropic"
CACHE_GET_ITEM = "cache.get_item"
COHERE_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.cohere"
COHERE_EMBEDDINGS_CREATE = "ai.embeddings.create.cohere"
DB = "db"
DB_REDIS = "db.redis"
EVENT_DJANGO = "event.django"
Expand Down
1 change: 1 addition & 0 deletions sentry_sdk/integrations/__init__.py
Expand Up @@ -78,6 +78,7 @@ def iter_default_integrations(with_auto_enabling_integrations):
"sentry_sdk.integrations.celery.CeleryIntegration",
"sentry_sdk.integrations.chalice.ChaliceIntegration",
"sentry_sdk.integrations.clickhouse_driver.ClickhouseDriverIntegration",
"sentry_sdk.integrations.cohere.CohereIntegration",
"sentry_sdk.integrations.django.DjangoIntegration",
"sentry_sdk.integrations.falcon.FalconIntegration",
"sentry_sdk.integrations.fastapi.FastApiIntegration",
Expand Down
257 changes: 257 additions & 0 deletions sentry_sdk/integrations/cohere.py
@@ -0,0 +1,257 @@
from functools import wraps

from sentry_sdk import consts
from sentry_sdk._types import TYPE_CHECKING
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.consts import SPANDATA
from sentry_sdk.ai.utils import set_data_normalized

if TYPE_CHECKING:
from typing import Any, Callable, Iterator
from sentry_sdk.tracing import Span

import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
)

try:
from cohere.client import Client
from cohere.base_client import BaseCohere
from cohere import ChatStreamEndEvent, NonStreamedChatResponse

if TYPE_CHECKING:
from cohere import StreamedChatResponse
except ImportError:
raise DidNotEnable("Cohere not installed")


# Chat kwargs that never contain prompt content; recorded on every chat
# span regardless of PII / include_prompts settings, under the mapped
# SPANDATA key.
COLLECTED_CHAT_PARAMS = {
    "model": SPANDATA.AI_MODEL_ID,
    "k": SPANDATA.AI_TOP_K,
    "p": SPANDATA.AI_TOP_P,
    "seed": SPANDATA.AI_SEED,
    "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
    "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
    "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
}

# Chat kwargs that may contain prompt/user content; only recorded when
# both send_default_pii and the integration's include_prompts are enabled.
COLLECTED_PII_CHAT_PARAMS = {
    "tools": SPANDATA.AI_TOOLS,
    "preamble": SPANDATA.AI_PREAMBLE,
}

# Response attributes copied onto every chat span (as "ai.<attr>").
COLLECTED_CHAT_RESP_ATTRS = {
    "generation_id": "ai.generation_id",
    "is_search_required": "ai.is_search_required",
    "finish_reason": "ai.finish_reason",
}

# Response attributes that may contain user data (as "ai.<attr>");
# copied only when PII collection is allowed.
COLLECTED_PII_CHAT_RESP_ATTRS = {
    "citations": "ai.citations",
    "documents": "ai.documents",
    "search_queries": "ai.search_queries",
    "search_results": "ai.search_results",
    "tool_calls": "ai.tool_calls",
}


class CohereIntegration(Integration):
    """Sentry integration that instruments the Cohere Python client.

    On setup it monkey-patches ``BaseCohere.chat``,
    ``BaseCohere.chat_stream`` and ``Client.embed`` so that each call is
    wrapped in an AI-monitoring span.
    """

    identifier = "cohere"

    def __init__(self, include_prompts=True):
        # type: (CohereIntegration, bool) -> None
        # When False, prompt and response content is never attached to
        # spans, even if send_default_pii is on.
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
        BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
        Client.embed = _wrap_embed(Client.embed)


def _capture_exception(exc):
    # type: (Any) -> None
    """Report an exception raised by a wrapped Cohere call to Sentry."""
    client = sentry_sdk.get_client()
    event, hint = event_from_exception(
        exc,
        client_options=client.options,
        mechanism={"type": "cohere", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)


def _wrap_chat(f, streaming):
    # type: (Callable[..., Any], bool) -> Callable[..., Any]
    """Wrap ``BaseCohere.chat`` / ``BaseCohere.chat_stream`` in a span.

    :param f: the original Cohere client method.
    :param streaming: True when wrapping ``chat_stream`` (the result is an
        iterator of events), False when wrapping ``chat`` (the result is a
        ``NonStreamedChatResponse``).
    """

    def collect_chat_response_fields(span, res, include_pii):
        # type: (Span, NonStreamedChatResponse, bool) -> None
        # Copy response attributes and token counts onto the span.
        if include_pii:
            if hasattr(res, "text"):
                set_data_normalized(
                    span,
                    SPANDATA.AI_RESPONSES,
                    [res.text],
                )
            for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
                if hasattr(res, pii_attr):
                    set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))

        for attr in COLLECTED_CHAT_RESP_ATTRS:
            if hasattr(res, attr):
                set_data_normalized(span, "ai." + attr, getattr(res, attr))

        if hasattr(res, "meta"):
            # Prefer billed_units over raw token counts when both exist.
            if hasattr(res.meta, "billed_units"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.billed_units.input_tokens,
                    completion_tokens=res.meta.billed_units.output_tokens,
                )
            elif hasattr(res.meta, "tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.tokens.input_tokens,
                    completion_tokens=res.meta.tokens.output_tokens,
                )

            if hasattr(res.meta, "warnings"):
                set_data_normalized(span, "ai.warnings", res.meta.warnings)

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_chat(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        # Only instrument calls made with a plain string `message` kwarg;
        # anything else is passed through untouched.
        if "message" not in kwargs:
            return f(*args, **kwargs)

        if not isinstance(kwargs.get("message"), str):
            return f(*args, **kwargs)

        message = kwargs.get("message")

        span = sentry_sdk.start_span(
            op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
            description="cohere.client.Chat",
        )
        # Entered manually because for streaming responses the span must
        # stay open until the returned iterator is exhausted.
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(CohereIntegration)

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(
                    span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    list(
                        map(
                            lambda x: {
                                "role": getattr(x, "role", "").lower(),
                                "content": getattr(x, "message", ""),
                            },
                            kwargs.get("chat_history", []),
                        )
                    )
                    + [{"role": "user", "content": message}],
                )
                for k, v in COLLECTED_PII_CHAT_PARAMS.items():
                    if k in kwargs:
                        set_data_normalized(span, v, kwargs[k])

            for k, v in COLLECTED_CHAT_PARAMS.items():
                if k in kwargs:
                    set_data_normalized(span, v, kwargs[k])
            # BUGFIX: was hard-coded to False, which mislabeled
            # chat_stream calls (wrapped with streaming=True).
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            if streaming:
                old_iterator = res

                def new_iterator():
                    # type: () -> Iterator[StreamedChatResponse]
                    # Re-yield every event; collect response fields from
                    # the terminal ChatStreamEndEvent, then close the span.
                    with capture_internal_exceptions():
                        for x in old_iterator:
                            if isinstance(x, ChatStreamEndEvent):
                                collect_chat_response_fields(
                                    span,
                                    x.response,
                                    include_pii=should_send_default_pii()
                                    and integration.include_prompts,
                                )
                            yield x

                    span.__exit__(None, None, None)

                return new_iterator()
            elif isinstance(res, NonStreamedChatResponse):
                collect_chat_response_fields(
                    span,
                    res,
                    include_pii=should_send_default_pii()
                    and integration.include_prompts,
                )
                span.__exit__(None, None, None)
            else:
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)

            return res

    return new_chat


def _wrap_embed(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Return ``Client.embed`` wrapped in an AI-monitoring span."""

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_embed(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        with sentry_sdk.start_span(
            op=consts.OP.COHERE_EMBEDDINGS_CREATE,
            description="Cohere Embedding Creation",
        ) as span:
            integration = sentry_sdk.get_client().get_integration(CohereIntegration)

            # Input texts are attached only when PII collection is allowed.
            texts = kwargs.get("texts")
            if (
                "texts" in kwargs
                and should_send_default_pii()
                and integration.include_prompts
            ):
                if isinstance(texts, str):
                    set_data_normalized(span, "ai.texts", [texts])
                elif (
                    isinstance(texts, list)
                    and len(texts) > 0
                    and isinstance(texts[0], str)
                ):
                    set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, texts)

            if "model" in kwargs:
                set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])

            try:
                res = f(*args, **kwargs)
            except Exception as e:
                _capture_exception(e)
                raise e from None

            # Embeddings only consume input tokens, so billed input tokens
            # double as the total.
            if (
                hasattr(res, "meta")
                and hasattr(res.meta, "billed_units")
                and hasattr(res.meta.billed_units, "input_tokens")
            ):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.billed_units.input_tokens,
                    total_tokens=res.meta.billed_units.input_tokens,
                )
            return res

    return new_embed
8 changes: 7 additions & 1 deletion sentry_sdk/integrations/langchain.py
Expand Up @@ -63,7 +63,12 @@ def count_tokens(s):

# To avoid double collecting tokens, we do *not* measure
# token counts for models for which we have an explicit integration
NO_COLLECT_TOKEN_MODELS = ["openai-chat"] # TODO add huggingface and anthropic
NO_COLLECT_TOKEN_MODELS = [
"openai-chat",
"anthropic-chat",
"cohere-chat",
"huggingface_endpoint",
]


class LangchainIntegration(Integration):
Expand Down Expand Up @@ -216,6 +221,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
watched_span.no_collect_tokens = any(
x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
)

if not model and "anthropic" in all_params.get("_type"):
model = "claude-2"
if model:
Expand Down
3 changes: 3 additions & 0 deletions tests/integrations/cohere/__init__.py
@@ -0,0 +1,3 @@
import pytest

pytest.importorskip("cohere")

0 comments on commit 40746ef

Please sign in to comment.