Add preliminary AI analytics SDK
colin-sentry committed Apr 23, 2024
1 parent 5d7c4a7 commit 369f8dd
Showing 4 changed files with 92 additions and 26 deletions.
78 changes: 78 additions & 0 deletions sentry_sdk/ai_analytics.py
@@ -0,0 +1,78 @@
from functools import wraps

from sentry_sdk import start_span
from sentry_sdk.tracing import Span
from sentry_sdk.utils import ContextVar
from sentry_sdk._types import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Optional, Callable, Any

_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)


def set_ai_pipeline_name(name):
    # type: (Optional[str]) -> None
    _ai_pipeline_name.set(name)


def get_ai_pipeline_name():
    # type: () -> Optional[str]
    return _ai_pipeline_name.get()


def ai_pipeline(description, op="ai.pipeline", **span_kwargs):
    # type: (str, str, Any) -> Callable[..., Any]
    def decorator(f):
        # type: (Callable[..., Any]) -> Callable[..., Any]
        @wraps(f)
        def wrapped(*args, **kwargs):
            # type: (Any, Any) -> Any
            with start_span(description=description, op=op, **span_kwargs):
                _ai_pipeline_name.set(description)
                res = f(*args, **kwargs)
                _ai_pipeline_name.set(None)
                return res

        return wrapped

    return decorator


def ai_run(description, op="ai.run", **span_kwargs):
    # type: (str, str, Any) -> Callable[..., Any]
    def decorator(f):
        # type: (Callable[..., Any]) -> Callable[..., Any]
        @wraps(f)
        def wrapped(*args, **kwargs):
            # type: (Any, Any) -> Any
            with start_span(description=description, op=op, **span_kwargs) as span:
                curr_pipeline = _ai_pipeline_name.get()
                if curr_pipeline:
                    span.set_data("ai.pipeline.name", curr_pipeline)
                return f(*args, **kwargs)

        return wrapped

    return decorator


def record_token_usage(
    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
):
    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
    ai_pipeline_name = get_ai_pipeline_name()
    if ai_pipeline_name:
        span.set_data("ai.pipeline.name", ai_pipeline_name)
    if prompt_tokens is not None:
        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
    if completion_tokens is not None:
        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
    if (
        total_tokens is None
        and prompt_tokens is not None
        and completion_tokens is not None
    ):
        total_tokens = prompt_tokens + completion_tokens
    if total_tokens is not None:
        span.set_measurement("ai_total_tokens_used", total_tokens)
20 changes: 1 addition & 19 deletions sentry_sdk/integrations/_ai_common.py
@@ -1,7 +1,7 @@
from sentry_sdk._types import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Optional
    from typing import Any

from sentry_sdk.tracing import Span
from sentry_sdk.utils import logger
@@ -30,21 +30,3 @@ def set_data_normalized(span, key, value):
    # type: (Span, str, Any) -> None
    normalized = _normalize_data(value)
    span.set_data(key, normalized)


def record_token_usage(
    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
):
    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
    if prompt_tokens is not None:
        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
    if completion_tokens is not None:
        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
    if (
        total_tokens is None
        and prompt_tokens is not None
        and completion_tokens is not None
    ):
        total_tokens = prompt_tokens + completion_tokens
    if total_tokens is not None:
        span.set_measurement("ai_total_tokens_used", total_tokens)
15 changes: 11 additions & 4 deletions sentry_sdk/integrations/langchain.py
@@ -3,8 +3,9 @@

import sentry_sdk
from sentry_sdk._types import TYPE_CHECKING
from sentry_sdk.ai_analytics import set_ai_pipeline_name, record_token_usage
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations._ai_common import set_data_normalized, record_token_usage
from sentry_sdk.integrations._ai_common import set_data_normalized
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing import Span

@@ -88,6 +89,7 @@ class WatchedSpan:
    num_prompt_tokens = 0  # type: int
    no_collect_tokens = False  # type: bool
    children = []  # type: List[WatchedSpan]
    is_pipeline = False  # type: bool

    def __init__(self, span):
        # type: (Span) -> None
@@ -134,9 +136,6 @@ def _normalize_langchain_message(self, message):
    def _create_span(self, run_id, parent_id, **kwargs):
        # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan

        if "origin" not in kwargs:
            kwargs["origin"] = "auto.ai.langchain"

        watched_span = None  # type: Optional[WatchedSpan]
        if parent_id:
            parent_span = self.span_map[parent_id]  # type: Optional[WatchedSpan]
@@ -146,6 +145,11 @@ def _create_span(self, run_id, parent_id, **kwargs):
        if watched_span is None:
            watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))

        if kwargs.get("op", "").startswith("ai.pipeline."):
            if kwargs.get("description"):
                set_ai_pipeline_name(kwargs.get("description"))
            watched_span.is_pipeline = True

        watched_span.span.__enter__()
        self.span_map[run_id] = watched_span
        self.gc_span_map()
@@ -154,6 +158,9 @@ def _exit_span(self, span_data, run_id):
    def _exit_span(self, span_data, run_id):
        # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None

        if span_data.is_pipeline:
            set_ai_pipeline_name(None)

        span_data.span.__exit__(None, None, None)
        del self.span_map[run_id]
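
The callback changes above boil down to the set-then-clear pattern the new module exposes. A hand-written equivalent might look roughly like the following; the span ops and descriptions are assumptions for illustration, not values taken from the integration.

# Rough sketch of the set/clear pattern the langchain callback now follows;
# the span ops and descriptions here are assumptions for illustration only.
import sentry_sdk
from sentry_sdk.ai_analytics import get_ai_pipeline_name, set_ai_pipeline_name

sentry_sdk.init()  # DSN and other options omitted

with sentry_sdk.start_span(op="ai.pipeline.langchain", description="my chain"):
    set_ai_pipeline_name("my chain")
    try:
        with sentry_sdk.start_span(op="ai.run.langchain", description="llm step") as child:
            # Child spans can look up the active pipeline name and tag
            # themselves, as the ai_run decorator and record_token_usage do.
            pipeline = get_ai_pipeline_name()
            if pipeline:
                child.set_data("ai.pipeline.name", pipeline)
    finally:
        # Mirror _exit_span: clear the ContextVar once the pipeline span closes.
        set_ai_pipeline_name(None)
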
5 changes: 2 additions & 3 deletions sentry_sdk/integrations/openai.py
@@ -2,8 +2,9 @@

from sentry_sdk import consts
from sentry_sdk._types import TYPE_CHECKING
from sentry_sdk.ai_analytics import record_token_usage
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations._ai_common import set_data_normalized, record_token_usage
from sentry_sdk.integrations._ai_common import set_data_normalized

if TYPE_CHECKING:
    from typing import Any, Iterable, List, Optional, Callable, Iterator
@@ -141,7 +142,6 @@ def new_chat_completion(*args, **kwargs):

        span = sentry_sdk.start_span(
            op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
            origin="auto.ai.openai",
            description="Chat Completion",
        )
        span.__enter__()
@@ -225,7 +225,6 @@ def new_embeddings_create(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        with sentry_sdk.start_span(
            op=consts.OP.OPENAI_EMBEDDINGS_CREATE,
            origin="auto.ai.openai",
            description="OpenAI Embedding Creation",
        ) as span:
            integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
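
For completeness, a sketch of feeding token counts from an OpenAI-style response into the relocated record_token_usage helper. This hand-rolls what the integration does automatically; the client call and response shape assume the openai>=1.0 Python package, and the model and message are placeholders.

# Hand-rolled sketch (illustrative only); the integration does this
# automatically, and the response shape assumes the openai>=1.0 client.
import sentry_sdk
from openai import OpenAI
from sentry_sdk import consts
from sentry_sdk.ai_analytics import record_token_usage

sentry_sdk.init()  # DSN and other options omitted
client = OpenAI()  # reads OPENAI_API_KEY from the environment

with sentry_sdk.start_span(
    op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
    description="Chat Completion",
) as span:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    usage = response.usage
    if usage is not None:
        # total_tokens may be omitted; record_token_usage sums the other
        # two counts when it is not given.
        record_token_usage(
            span,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
        )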
