streaming estimation
Kyle-Verhoog committed Apr 22, 2023
1 parent 9287fbf commit bff7ad7
Showing 3 changed files with 174 additions and 33 deletions.
11 changes: 10 additions & 1 deletion ddtrace/contrib/openai/_logging.py
@@ -1,4 +1,7 @@
import json

# logger = get_logger(__name__)
import logging
import threading
from typing import List
from typing import TypedDict
@@ -7,9 +10,10 @@
from ddtrace.internal.compat import httplib
from ddtrace.internal.logger import get_logger
from ddtrace.internal.periodic import PeriodicService
from ddtrace.internal.service import ServiceStatus


logger = get_logger(__file__)
logger = logging.getLogger(__name__)


class V2LogEvent(TypedDict):
@@ -54,10 +58,15 @@ def __init__(self, site, api_key, interval, timeout):
"DD-API-KEY": self._api_key,
"Content-Type": "application/json",
}
logger.debug("starting log writer to %s/%s", self._site, self._endpoint)

def enqueue(self, log):
# type: (V2LogEvent) -> None
        with self._lock:
            if self.status != ServiceStatus.RUNNING:
                # Lazily start the writer on the first enqueued log.
                self.start()

if len(self._buffer) >= self._buffer_limit:
logger.warning("log buffer full (limit is %d), dropping log", self._buffer_limit)
return
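The lazy-start logic above is easy to misread, so here is a standalone sketch of the pattern this hunk is aiming for (a hypothetical class, not the ddtrace PeriodicService API): buffer logs and spin up the flush thread only once the first log arrives.

import threading


class LazyLogWriter:
    """Sketch: start the background flusher on first use, then buffer."""

    def __init__(self, buffer_limit=1000):
        self._lock = threading.Lock()
        self._started = False
        self._buffer = []
        self._buffer_limit = buffer_limit

    def enqueue(self, log):
        with self._lock:
            if not self._started:
                # First log seen: start the flush thread exactly once.
                self._started = True
                threading.Thread(target=self._flush_loop, daemon=True).start()
            if len(self._buffer) >= self._buffer_limit:
                return  # drop the log rather than grow without bound
            self._buffer.append(log)

    def _flush_loop(self):
        pass  # periodically drain self._buffer to the intake endpoint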
128 changes: 101 additions & 27 deletions ddtrace/contrib/openai/patch.py
@@ -1,4 +1,6 @@
from collections.abc import Iterable
import os
import re
import sys
import time

@@ -86,7 +88,7 @@ def trace(self, pin, endpoint, model):
v = getattr(self._openai, attr)
if v is not None:
if attr == "organization_id":
span.set_tag_str("organization.id", v)
span.set_tag_str("organization.id", v or "")
else:
span.set_tag_str(attr, v)
span.set_tag_str("endpoint", endpoint)
@@ -124,10 +126,10 @@ def _metrics_tags(self, span):
"version:%s" % (config.version or ""),
"env:%s" % (config.env or ""),
"service:%s" % (span.service or ""),
"model:%s" % span.get_tag("model"),
"endpoint:%s" % span.get_tag("endpoint"),
"organization.id:%s" % span.get_tag("organization.id"),
"organization.name:%s" % span.get_tag("organization.name"),
"model:%s" % (span.get_tag("model") or ""),
"endpoint:%s" % (span.get_tag("endpoint") or ""),
"organization.id:%s" % (span.get_tag("organization.id") or ""),
"organization.name:%s" % (span.get_tag("organization.name") or ""),
"error:%d" % span.error,
]
err_type = span.get_tag("error.type")
@@ -146,9 +148,11 @@ def metric(self, span, kind, name, val):
self._statsd.increment(name, val, tags=tags)
elif kind == "gauge":
self._statsd.gauge(name, val, tags=tags)
else:
raise ValueError("Unexpected metric type %r" % kind)

def record_usage(self, span, usage):
if not usage:
if not usage or not self._config.metrics_enabled:
return
tags = self._metrics_tags(span)
for token_type in ["prompt", "completion", "total"]:
@@ -270,7 +274,7 @@ def _traced_endpoint(endpoint_hook, integration, openai, pin, instance, args, kw
span = integration.trace(pin, instance.OBJECT_NAME, kwargs.get("model"))
try:
# Start the hook
hook = endpoint_hook().handle_request(integration, span, args, kwargs)
hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs)
hook.send(None)

resp, error = yield
@@ -283,11 +287,19 @@ def _traced_endpoint(endpoint_hook, integration, openai, pin, instance, args, kw
# Pass the response and the error to the hook
try:
hook.send((resp, error))
except StopIteration:
pass
except StopIteration as e:
if error is None:
return e.value
finally:
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
# total_tokens = span.get_tag("response.usage.total_tokens")
# # TODO: shouldn't be here for stream
# if total_tokens is None:
# prompt_tokens = span.get_tag("response.usage.prompt_tokens")
# completion_tokens = span.get_tag("response.usage.completion_tokens")
# if prompt_tokens is not None and completion_tokens is not None:
# span.set_metric("response.usage.total_tokens", prompt_tokens + completion_tokens)
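For readers unfamiliar with the pattern used in this hunk: each endpoint hook is a generator that is primed with send(None), later fed the (response, error) pair, and whose StopIteration value, when set, replaces the response. A minimal sketch of that protocol (illustrative names, not the ddtrace implementation):

def hook():
    resp, error = yield  # paused here until the endpoint call returns
    if error is None:
        resp = {"wrapped": resp}  # the returned value overrides the response
    return resp


g = hook()
g.send(None)  # prime the generator up to its first yield
try:
    g.send(("raw-response", None))
except StopIteration as e:
    print(e.value)  # {'wrapped': 'raw-response'}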


def _patched_endpoint(integration, patch_hook):
@@ -304,7 +316,10 @@ def patched_endpoint(openai, pin, func, instance, args, kwargs):
raise
finally:
try:
g.send((resp, err))
r = g.send((resp, err))
if err is None:
# This return takes priority over `return resp`
return r
except StopIteration:
pass

@@ -327,14 +342,33 @@ async def patched_endpoint(openai, pin, func, instance, args, kwargs):
finally:
try:
g.send((resp, err))
except StopIteration:
pass
except StopIteration as e:
if err is None:
# This return takes priority over `return resp`
return e.value

return patched_endpoint


def _est_tokens(s):
# type: (str) -> int
"""Provide a very rough estimate of the number of tokens.
Approximate using the following assumptions:
1 token ~= 4 chars
1 token ~= ¾ words
Note that this function is 3x faster than tiktoken's encoding.
"""
est1 = len(s.strip()) / 4
est2 = len(s.split(" ")) * 1.25
est3 = len(re.findall(r"[\w']+|[.,!?;]", s)) * 0.75
return int((est1 + est2 + est3) / 3)
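To make the averaging concrete, here is the arithmetic for a two-word prompt (this matches the expectations in test_est_tokens further down):

# _est_tokens("hello world"):
#   est1 = 11 chars / 4               = 2.75
#   est2 = 2 whitespace words * 1.25  = 2.50
#   est3 = 2 regex tokens * 0.75      = 1.50
#   int((2.75 + 2.50 + 1.50) / 3) = int(2.25) = 2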


class _EndpointHook:
def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
raise NotImplementedError


@@ -353,6 +387,28 @@ def _record_request(self, span, kwargs):
else:
span.set_tag("request.%s" % kw_attr, kwargs[kw_attr])

def _handle_response(self, pin, span, integration, resp):
async def wrap_gen():
num_completion_tokens = 0
stream_span = pin.tracer.start_span("openai.stream", child_of=span, activate=True)
            num_prompt_tokens = span.get_metric("response.usage.prompt_tokens") or 0
try:
async for chunk in resp:
num_completion_tokens += 1
                    yield chunk  # note: ``yield from`` is not valid inside an async generator
finally:
stream_span.set_metric("response.usage.completion_tokens", num_completion_tokens)
total_tokens = num_prompt_tokens + num_completion_tokens
stream_span.set_metric("response.usage.total_tokens", total_tokens)
integration.metric(span, "dist", "tokens.completion", num_completion_tokens)
integration.metric(span, "dist", "tokens.total", total_tokens)
# ``span`` could be flushed here so this is a best effort to attach the metric
span.set_metric("response.usage.completion_tokens", num_completion_tokens)
span.set_metric("response.usage.total_tokens", total_tokens)
stream_span.finish()

return wrap_gen()
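The wrapper re-yields each chunk and defers its bookkeeping to a finally block, so token metrics are emitted even if the consumer abandons the stream early. A self-contained sketch of the same pattern (hypothetical names, independent of ddtrace):

import asyncio


async def count_chunks(stream, on_done):
    """Re-yield chunks from ``stream``; report the count when it ends."""
    n = 0
    try:
        async for chunk in stream:
            n += 1
            yield chunk
    finally:
        on_done(n)  # runs on exhaustion, error, or early close


async def main():
    async def fake_stream():
        for c in ("Hello", " ", "world"):
            yield c

    async for chunk in count_chunks(fake_stream(), lambda n: print("chunks:", n)):
        print(repr(chunk))


asyncio.run(main())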


class _CompletionHook(_BaseCompletionHook):
_request_tag_attrs = [
@@ -372,17 +428,28 @@ class _CompletionHook(_BaseCompletionHook):
"user",
]

def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
sample_pc_span = integration.is_pc_sampled_span(span)

if sample_pc_span:
prompt = kwargs.get("prompt", "")
if isinstance(prompt, list):
            if isinstance(prompt, str):
                # Strings are also Iterable, so handle them first.
                if prompt:
                    span.set_tag_str("request.prompt", integration.trunc(prompt))
            elif isinstance(prompt, Iterable):
                for idx, p in enumerate(prompt):
                    span.set_tag_str("request.prompt.%d" % idx, integration.trunc(p))

if "stream" in kwargs and kwargs["stream"]:
prompt = kwargs.get("prompt", "")
num_prompt_tokens = 0
if isinstance(prompt, str):
num_prompt_tokens += _est_tokens(prompt)
else:
for p in prompt:
num_prompt_tokens += _est_tokens(p)
span.set_metric("response.usage.prompt_tokens", num_prompt_tokens)
integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens)

self._record_request(span, kwargs)

resp, error = yield
@@ -412,6 +479,7 @@ def handle_request(self, integration, span, args, kwargs):
"choices": resp["choices"] if resp and "choices" in resp else [],
},
)
return self._handle_response(pin, span, integration, resp)


class _ChatCompletionHook(_BaseCompletionHook):
@@ -428,22 +496,22 @@ class _ChatCompletionHook(_BaseCompletionHook):
"user",
]

def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
sample_pc_span = integration.is_pc_sampled_span(span)
messages = kwargs.get("messages")
if sample_pc_span and messages:

def set_message_tag(m):
for idx, m in enumerate(messages):
content = integration.trunc(m.get("content", ""))
role = integration.trunc(m.get("role", ""))
span.set_tag_str("request.messages.%d.content" % idx, content)
span.set_tag_str("request.messages.%d.role" % idx, role)

if isinstance(messages, list):
for idx, message in enumerate(messages):
set_message_tag(message)
else:
set_message_tag(messages)
if "stream" in kwargs and kwargs["stream"]:
num_message_tokens = 0
for m in messages:
num_message_tokens += _est_tokens(m.get("content", ""))
span.set_metric("response.usage.prompt_tokens", num_message_tokens)
integration.metric(span, "dist", "tokens.prompt", num_message_tokens)

self._record_request(span, kwargs)

@@ -481,10 +549,11 @@ def set_message_tag(m):
"completion": completions,
},
)
return self._handle_response(pin, span, integration, resp)


class _EmbeddingHook(_EndpointHook):
def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
for kw_attr in ["model", "input", "user"]:
if kw_attr in kwargs:
if kw_attr == "input" and isinstance(kwargs["input"], list):
@@ -509,10 +578,15 @@ def _patched_convert(integration):
def patched_convert(openai, pin, func, instance, args, kwargs):
"""Patch convert captures header information in the openai response"""
for val in args:
            # FIXME: these tags are re-reported for every streamed chunk.
            # The organization.name tag doubles as a signal that the headers
            # were already recorded; TODO: find a better signal.
span = pin.tracer.current_span()
if not span:
return func(*args, **kwargs)
if span.get_tag("organization.name") is not None:
continue
if isinstance(val, openai.openai_response.OpenAIResponse):
span = pin.tracer.current_span()
if not span:
return func(*args, **kwargs)
val = val._headers
if val.get("openai-organization"):
org_name = val.get("openai-organization")
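A condensed view of what the hunk above does once the guard passes (names assumed from the surrounding diff): the organization reported in OpenAI's response headers is tagged onto the current span, and the presence of that tag short-circuits the work for subsequent chunks of the same stream.

def record_org(span, headers):
    if span.get_tag("organization.name") is not None:
        return  # already recorded by an earlier chunk of this stream
    org = headers.get("openai-organization")
    if org:
        span.set_tag_str("organization.name", org)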
68 changes: 63 additions & 5 deletions tests/contrib/openai/test_openai.py
@@ -1,4 +1,5 @@
import os
from typing import AsyncGenerator
from typing import List
from typing import Optional

@@ -10,6 +11,7 @@
from ddtrace import Pin
from ddtrace import Span
from ddtrace import patch
from ddtrace.contrib.openai.patch import _est_tokens
from ddtrace.contrib.openai.patch import unpatch
from ddtrace.filters import TraceFilter
from tests.utils import DummyTracer
@@ -91,7 +93,7 @@ def snapshot_tracer(patch_openai, mock_logs, mock_metrics):
def mock_tracer(patch_openai, mock_logs, mock_metrics):
pin = Pin.get_from(openai)
mock_tracer = DummyTracer()
pin.override(openai, tracer=DummyTracer())
pin.override(openai, tracer=mock_tracer)
pin.tracer.configure(settings={"FILTERS": [FilterOrg()]})

yield mock_tracer
@@ -112,7 +114,7 @@ def test_completion(mock_metrics, mock_logs, snapshot_tracer):
"service:",
"model:ada",
"endpoint:completions",
"organization.id:None",
"organization.id:",
"organization.name:datadog-4",
"error:0",
]
@@ -331,11 +333,47 @@ def test_completion_stream(snapshot_tracer):
openai.Completion.create(model="ada", prompt="Hello world", stream=True)


@pytest.mark.snapshot(ignores=["meta.http.useragent"])
@pytest.mark.asyncio
async def test_completion_async_stream(snapshot_tracer):
async def test_completion_async_stream(mock_metrics, mock_tracer):
with openai_vcr.use_cassette("completion_async_streamed.yaml"):
await openai.Completion.acreate(model="ada", prompt="Hello world", stream=True)
resp = await openai.Completion.acreate(model="ada", prompt="Hello world", stream=True)

assert isinstance(resp, AsyncGenerator)
chunks = []
async for chunk in resp:
chunks.append(chunk)

traces = mock_tracer.pop_traces()
assert len(traces) == 2
t1, t2 = traces
assert len(t1) == len(t2) == 1
assert t2[0].parent_id == t1[0].span_id

expected_tags = [
"version:",
"env:",
"service:",
"model:ada",
"endpoint:completions",
"organization.id:",
"organization.name:",
"error:0",
]
mock_metrics.assert_has_calls(
[
mock.call.distribution(
"tokens.prompt",
2,
tags=expected_tags,
),
mock.call.distribution(
"tokens.completion",
len(chunks),
tags=expected_tags,
)
]
)


@pytest.mark.snapshot(ignores=["meta.http.useragent"])
@@ -604,3 +642,23 @@ def _test_logs(mock_log):
assert (rate - 15) < logs < (rate + 15)

_test_logs()


def test_est_tokens():
"""
Oracle numbers come from https://platform.openai.com/tokenizer
"""
    est = _est_tokens
assert est("hello world") == 2
assert est("Hello world, how are you?") == 7 - 2
assert est("hello") == 1
assert est("") == 0
assert (
est(
"""
A helpful rule of thumb is that one token generally corresponds to ~4 characters of text for common English text. This translates to roughly ¾ of a word (so 100 tokens ~= 75 words).
If you need a programmatic interface for tokenizing text, check out our tiktoken package for Python. For JavaScript, the gpt-3-encoder package for node.js works for most GPT-3 models."""
)
== 75
)
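To sanity-check the heuristic against a real tokenizer, a quick local comparison (assumes the tiktoken package is installed; not part of this change set):

import tiktoken

from ddtrace.contrib.openai.patch import _est_tokens

enc = tiktoken.get_encoding("gpt2")  # GPT-3-era models such as "ada" use the GPT-2 BPE
for text in ("hello world", "Hello world, how are you?"):
    print(text, "->", _est_tokens(text), "estimated vs", len(enc.encode(text)), "actual")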
