Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ai-monitoring): Cohere integration (#3055)
* Cohere integration * Fix lint * Fix bug with model ID not being pulled * Exclude known models from langchain * tox.ini * Removed print statement * Apply suggestions from code review Co-authored-by: Anton Pirker <anton.pirker@sentry.io> --------- Co-authored-by: Anton Pirker <anton.pirker@sentry.io>
- Loading branch information
1 parent
1a32183
commit 40746ef
Showing
10 changed files
with
523 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,7 @@ | |
"arq", | ||
"beam", | ||
"celery", | ||
"cohere", | ||
"huey", | ||
"langchain", | ||
"openai", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
from functools import wraps | ||
|
||
from sentry_sdk import consts | ||
from sentry_sdk._types import TYPE_CHECKING | ||
from sentry_sdk.ai.monitoring import record_token_usage | ||
from sentry_sdk.consts import SPANDATA | ||
from sentry_sdk.ai.utils import set_data_normalized | ||
|
||
if TYPE_CHECKING: | ||
from typing import Any, Callable, Iterator | ||
from sentry_sdk.tracing import Span | ||
|
||
import sentry_sdk | ||
from sentry_sdk.scope import should_send_default_pii | ||
from sentry_sdk.integrations import DidNotEnable, Integration | ||
from sentry_sdk.utils import ( | ||
capture_internal_exceptions, | ||
event_from_exception, | ||
ensure_integration_enabled, | ||
) | ||
|
||
try: | ||
from cohere.client import Client | ||
from cohere.base_client import BaseCohere | ||
from cohere import ChatStreamEndEvent, NonStreamedChatResponse | ||
|
||
if TYPE_CHECKING: | ||
from cohere import StreamedChatResponse | ||
except ImportError: | ||
raise DidNotEnable("Cohere not installed") | ||
|
||
|
||
COLLECTED_CHAT_PARAMS = { | ||
"model": SPANDATA.AI_MODEL_ID, | ||
"k": SPANDATA.AI_TOP_K, | ||
"p": SPANDATA.AI_TOP_P, | ||
"seed": SPANDATA.AI_SEED, | ||
"frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY, | ||
"presence_penalty": SPANDATA.AI_PRESENCE_PENALTY, | ||
"raw_prompting": SPANDATA.AI_RAW_PROMPTING, | ||
} | ||
|
||
COLLECTED_PII_CHAT_PARAMS = { | ||
"tools": SPANDATA.AI_TOOLS, | ||
"preamble": SPANDATA.AI_PREAMBLE, | ||
} | ||
|
||
COLLECTED_CHAT_RESP_ATTRS = { | ||
"generation_id": "ai.generation_id", | ||
"is_search_required": "ai.is_search_required", | ||
"finish_reason": "ai.finish_reason", | ||
} | ||
|
||
COLLECTED_PII_CHAT_RESP_ATTRS = { | ||
"citations": "ai.citations", | ||
"documents": "ai.documents", | ||
"search_queries": "ai.search_queries", | ||
"search_results": "ai.search_results", | ||
"tool_calls": "ai.tool_calls", | ||
} | ||
|
||
|
||
class CohereIntegration(Integration): | ||
identifier = "cohere" | ||
|
||
def __init__(self, include_prompts=True): | ||
# type: (CohereIntegration, bool) -> None | ||
self.include_prompts = include_prompts | ||
|
||
@staticmethod | ||
def setup_once(): | ||
# type: () -> None | ||
BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False) | ||
Client.embed = _wrap_embed(Client.embed) | ||
BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True) | ||
|
||
|
||
def _capture_exception(exc): | ||
# type: (Any) -> None | ||
event, hint = event_from_exception( | ||
exc, | ||
client_options=sentry_sdk.get_client().options, | ||
mechanism={"type": "cohere", "handled": False}, | ||
) | ||
sentry_sdk.capture_event(event, hint=hint) | ||
|
||
|
||
def _wrap_chat(f, streaming): | ||
# type: (Callable[..., Any], bool) -> Callable[..., Any] | ||
|
||
def collect_chat_response_fields(span, res, include_pii): | ||
# type: (Span, NonStreamedChatResponse, bool) -> None | ||
if include_pii: | ||
if hasattr(res, "text"): | ||
set_data_normalized( | ||
span, | ||
SPANDATA.AI_RESPONSES, | ||
[res.text], | ||
) | ||
for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS: | ||
if hasattr(res, pii_attr): | ||
set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr)) | ||
|
||
for attr in COLLECTED_CHAT_RESP_ATTRS: | ||
if hasattr(res, attr): | ||
set_data_normalized(span, "ai." + attr, getattr(res, attr)) | ||
|
||
if hasattr(res, "meta"): | ||
if hasattr(res.meta, "billed_units"): | ||
record_token_usage( | ||
span, | ||
prompt_tokens=res.meta.billed_units.input_tokens, | ||
completion_tokens=res.meta.billed_units.output_tokens, | ||
) | ||
elif hasattr(res.meta, "tokens"): | ||
record_token_usage( | ||
span, | ||
prompt_tokens=res.meta.tokens.input_tokens, | ||
completion_tokens=res.meta.tokens.output_tokens, | ||
) | ||
|
||
if hasattr(res.meta, "warnings"): | ||
set_data_normalized(span, "ai.warnings", res.meta.warnings) | ||
|
||
@wraps(f) | ||
@ensure_integration_enabled(CohereIntegration, f) | ||
def new_chat(*args, **kwargs): | ||
# type: (*Any, **Any) -> Any | ||
if "message" not in kwargs: | ||
return f(*args, **kwargs) | ||
|
||
if not isinstance(kwargs.get("message"), str): | ||
return f(*args, **kwargs) | ||
|
||
message = kwargs.get("message") | ||
|
||
span = sentry_sdk.start_span( | ||
op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE, | ||
description="cohere.client.Chat", | ||
) | ||
span.__enter__() | ||
try: | ||
res = f(*args, **kwargs) | ||
except Exception as e: | ||
_capture_exception(e) | ||
span.__exit__(None, None, None) | ||
raise e from None | ||
|
||
integration = sentry_sdk.get_client().get_integration(CohereIntegration) | ||
|
||
with capture_internal_exceptions(): | ||
if should_send_default_pii() and integration.include_prompts: | ||
set_data_normalized( | ||
span, | ||
SPANDATA.AI_INPUT_MESSAGES, | ||
list( | ||
map( | ||
lambda x: { | ||
"role": getattr(x, "role", "").lower(), | ||
"content": getattr(x, "message", ""), | ||
}, | ||
kwargs.get("chat_history", []), | ||
) | ||
) | ||
+ [{"role": "user", "content": message}], | ||
) | ||
for k, v in COLLECTED_PII_CHAT_PARAMS.items(): | ||
if k in kwargs: | ||
set_data_normalized(span, v, kwargs[k]) | ||
|
||
for k, v in COLLECTED_CHAT_PARAMS.items(): | ||
if k in kwargs: | ||
set_data_normalized(span, v, kwargs[k]) | ||
set_data_normalized(span, SPANDATA.AI_STREAMING, False) | ||
|
||
if streaming: | ||
old_iterator = res | ||
|
||
def new_iterator(): | ||
# type: () -> Iterator[StreamedChatResponse] | ||
|
||
with capture_internal_exceptions(): | ||
for x in old_iterator: | ||
if isinstance(x, ChatStreamEndEvent): | ||
collect_chat_response_fields( | ||
span, | ||
x.response, | ||
include_pii=should_send_default_pii() | ||
and integration.include_prompts, | ||
) | ||
yield x | ||
|
||
span.__exit__(None, None, None) | ||
|
||
return new_iterator() | ||
elif isinstance(res, NonStreamedChatResponse): | ||
collect_chat_response_fields( | ||
span, | ||
res, | ||
include_pii=should_send_default_pii() | ||
and integration.include_prompts, | ||
) | ||
span.__exit__(None, None, None) | ||
else: | ||
set_data_normalized(span, "unknown_response", True) | ||
span.__exit__(None, None, None) | ||
return res | ||
|
||
return new_chat | ||
|
||
|
||
def _wrap_embed(f): | ||
# type: (Callable[..., Any]) -> Callable[..., Any] | ||
|
||
@wraps(f) | ||
@ensure_integration_enabled(CohereIntegration, f) | ||
def new_embed(*args, **kwargs): | ||
# type: (*Any, **Any) -> Any | ||
with sentry_sdk.start_span( | ||
op=consts.OP.COHERE_EMBEDDINGS_CREATE, | ||
description="Cohere Embedding Creation", | ||
) as span: | ||
integration = sentry_sdk.get_client().get_integration(CohereIntegration) | ||
if "texts" in kwargs and ( | ||
should_send_default_pii() and integration.include_prompts | ||
): | ||
if isinstance(kwargs["texts"], str): | ||
set_data_normalized(span, "ai.texts", [kwargs["texts"]]) | ||
elif ( | ||
isinstance(kwargs["texts"], list) | ||
and len(kwargs["texts"]) > 0 | ||
and isinstance(kwargs["texts"][0], str) | ||
): | ||
set_data_normalized( | ||
span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"] | ||
) | ||
|
||
if "model" in kwargs: | ||
set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"]) | ||
try: | ||
res = f(*args, **kwargs) | ||
except Exception as e: | ||
_capture_exception(e) | ||
raise e from None | ||
if ( | ||
hasattr(res, "meta") | ||
and hasattr(res.meta, "billed_units") | ||
and hasattr(res.meta.billed_units, "input_tokens") | ||
): | ||
record_token_usage( | ||
span, | ||
prompt_tokens=res.meta.billed_units.input_tokens, | ||
total_tokens=res.meta.billed_units.input_tokens, | ||
) | ||
return res | ||
|
||
return new_embed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import pytest | ||
|
||
pytest.importorskip("cohere") |
Oops, something went wrong.