Commit

streaming estimation
Kyle-Verhoog committed Apr 22, 2023
1 parent 9287fbf commit b00b994
Showing 4 changed files with 286 additions and 81 deletions.
11 changes: 10 additions & 1 deletion ddtrace/contrib/openai/_logging.py
@@ -1,4 +1,7 @@
import json

# logger = get_logger(__name__)
import logging
import threading
from typing import List
from typing import TypedDict
@@ -7,9 +10,10 @@
from ddtrace.internal.compat import httplib
from ddtrace.internal.logger import get_logger
from ddtrace.internal.periodic import PeriodicService
from ddtrace.internal.service import ServiceStatus


logger = get_logger(__file__)
logger = logging.getLogger(__name__)


class V2LogEvent(TypedDict):
@@ -54,10 +58,15 @@ def __init__(self, site, api_key, interval, timeout):
"DD-API-KEY": self._api_key,
"Content-Type": "application/json",
}
logger.debug("starting log writer to %s/%s", self._site, self._endpoint)

def enqueue(self, log):
# type: (V2LogEvent) -> None
with self._lock:
            if self.status != ServiceStatus.RUNNING:
                self.start()

if len(self._buffer) >= self._buffer_limit:
logger.warning("log buffer full (limit is %d), dropping log", self._buffer_limit)
return
180 changes: 144 additions & 36 deletions ddtrace/contrib/openai/patch.py
@@ -1,6 +1,8 @@
import os
import re
import sys
import time
from typing import AsyncGenerator

from ddtrace import config
from ddtrace.constants import SPAN_MEASURED_KEY
@@ -48,7 +50,6 @@ def __init__(self, config, openai, stats_url, site, api_key):
)
self._span_pc_sampler = RateSampler(sample_rate=config.span_prompt_completion_sample_rate)
self._log_pc_sampler = RateSampler(sample_rate=config.log_prompt_completion_sample_rate)
self._statsd._enabled = config.metrics_enabled
self._openai = openai

def is_pc_sampled_span(self, span):
@@ -64,11 +65,6 @@ def is_pc_sampled_log(self, span):
def start_log_writer(self):
self._log_writer.start()

def _ust_tags(self):
# Do this dynamically to ensure any changes to ddtrace.config.*
# are respected here.
return ["%s:%s" % (k, v) for k, v in [("env", config.env), ("version", config.version)] if v]

def trace(self, pin, endpoint, model):
"""Start an OpenAI span.
@@ -86,7 +82,7 @@ def trace(self, pin, endpoint, model):
v = getattr(self._openai, attr)
if v is not None:
if attr == "organization_id":
span.set_tag_str("organization.id", v)
span.set_tag_str("organization.id", v or "")
else:
span.set_tag_str(attr, v)
span.set_tag_str("endpoint", endpoint)
@@ -124,10 +120,10 @@ def _metrics_tags(self, span):
"version:%s" % (config.version or ""),
"env:%s" % (config.env or ""),
"service:%s" % (span.service or ""),
"model:%s" % span.get_tag("model"),
"endpoint:%s" % span.get_tag("endpoint"),
"organization.id:%s" % span.get_tag("organization.id"),
"organization.name:%s" % span.get_tag("organization.name"),
"model:%s" % (span.get_tag("model") or ""),
"endpoint:%s" % (span.get_tag("endpoint") or ""),
"organization.id:%s" % (span.get_tag("organization.id") or ""),
"organization.name:%s" % (span.get_tag("organization.name") or ""),
"error:%d" % span.error,
]
err_type = span.get_tag("error.type")
@@ -137,6 +133,7 @@

def metric(self, span, kind, name, val):
"""Set a metric using the OpenAI context from the given span."""
if not self._config.metrics_enabled:
return
tags = self._metrics_tags(span)
@@ -146,9 +143,11 @@
self._statsd.increment(name, val, tags=tags)
elif kind == "gauge":
self._statsd.gauge(name, val, tags=tags)
else:
raise ValueError("Unexpected metric type %r" % kind)

def record_usage(self, span, usage):
if not usage:
if not usage or not self._config.metrics_enabled:
return
tags = self._metrics_tags(span)
for token_type in ["prompt", "completion", "total"]:
@@ -270,7 +269,7 @@ def _traced_endpoint(endpoint_hook, integration, openai, pin, instance, args, kw
span = integration.trace(pin, instance.OBJECT_NAME, kwargs.get("model"))
try:
# Start the hook
hook = endpoint_hook().handle_request(integration, span, args, kwargs)
hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs)
hook.send(None)

resp, error = yield
@@ -283,11 +282,19 @@ def _traced_endpoint(endpoint_hook, integration, openai, pin, instance, args, kw
# Pass the response and the error to the hook
try:
hook.send((resp, error))
except StopIteration:
pass
except StopIteration as e:
if error is None:
return e.value
finally:
span.finish()
integration.metric(span, "dist", "request.duration", span.duration_ns)
# total_tokens = span.get_tag("response.usage.total_tokens")
# # TODO: shouldn't be here for stream
# if total_tokens is None:
# prompt_tokens = span.get_tag("response.usage.prompt_tokens")
# completion_tokens = span.get_tag("response.usage.completion_tokens")
# if prompt_tokens is not None and completion_tokens is not None:
# span.set_metric("response.usage.total_tokens", prompt_tokens + completion_tokens)


def _patched_endpoint(integration, patch_hook):
@@ -305,8 +312,10 @@ def patched_endpoint(openai, pin, func, instance, args, kwargs):
finally:
try:
g.send((resp, err))
except StopIteration:
pass
except StopIteration as e:
if err is None:
# This return takes priority over `return resp`
return e.value

return patched_endpoint

@@ -327,14 +336,33 @@ async def patched_endpoint(openai, pin, func, instance, args, kwargs):
finally:
try:
g.send((resp, err))
except StopIteration:
pass
except StopIteration as e:
if err is None:
# This return takes priority over `return resp`
return e.value

return patched_endpoint


def _est_tokens(s):
# type: (str) -> int
"""Provide a very rough estimate of the number of tokens.
Approximate using the following assumptions:
1 token ~= 4 chars
1 token ~= ¾ words
Note that this function is 3x faster than tiktoken's encoding.
"""
est1 = len(s.strip()) / 4
est2 = len(s.split(" ")) * 1.25
est3 = len(re.findall(r"[\w']+|[.,!?;]", s)) * 0.75
return int((est1 + est2 + est3) / 3)
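
As a rough sanity check, the heuristic can be compared against an exact tokenizer. The snippet below is an illustrative sketch only; it assumes the tiktoken package and the cl100k_base encoding, neither of which this patch depends on:

    import tiktoken

    def compare_token_estimate(text):
        enc = tiktoken.get_encoding("cl100k_base")  # assumed encoding, for illustration only
        exact = len(enc.encode(text))               # exact token count from tiktoken
        approx = _est_tokens(text)                  # rough estimate used by this patch
        return exact, approx

    # e.g. compare_token_estimate("Streaming responses do not include a usage field.")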


class _EndpointHook:
def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
raise NotImplementedError


@@ -353,6 +381,68 @@ def _record_request(self, span, kwargs):
else:
span.set_tag("request.%s" % kw_attr, kwargs[kw_attr])

def _handle_response(self, pin, span, integration, resp):
"""Handle the response object returned from endpoint calls.
This method helps with streamed responses by wrapping the generator returned with a
generator that traces the reading of the response.
"""

def shared_gen():
""" """
stream_span = pin.tracer.start_span("openai.stream", child_of=span, activate=True)
num_prompt_tokens = span.get_metric("response.usage.prompt_tokens") or 0

num_completion_tokens = yield

stream_span.set_metric("response.usage.completion_tokens", num_completion_tokens)
total_tokens = num_prompt_tokens + num_completion_tokens
stream_span.set_metric("response.usage.total_tokens", total_tokens)
integration.metric(span, "dist", "tokens.completion", num_completion_tokens)
integration.metric(span, "dist", "tokens.total", total_tokens)
stream_span.finish()

            # ``span`` may already have been flushed by this point, so attaching
            # these metrics to it is best-effort
span.set_metric("response.usage.completion_tokens", num_completion_tokens)
span.set_metric("response.usage.total_tokens", total_tokens)

# A chunk corresponds to a token:
# https://community.openai.com/t/how-to-get-total-tokens-from-a-stream-of-completioncreaterequests/110700
# https://community.openai.com/t/openai-api-get-usage-tokens-in-response-when-set-stream-true/141866
if isinstance(resp, AsyncGenerator):

async def traced_streamed_response():
g = shared_gen()
g.send(None)
num_completion_tokens = 0
try:
async for chunk in resp:
num_completion_tokens += 1
yield chunk
finally:
try:
g.send(num_completion_tokens)
except StopIteration:
pass

else:

def traced_streamed_response():
g = shared_gen()
g.send(None)
num_completion_tokens = 0
try:
for chunk in resp:
num_completion_tokens += 1
yield chunk
finally:
try:
g.send(num_completion_tokens)
except StopIteration:
pass

return traced_streamed_response()
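
For reference, this is the call pattern the wrapper supports when a caller requests a streamed completion. A minimal usage sketch, assuming the pre-1.0 openai client that this integration patches (the model and prompt are illustrative):

    import openai

    resp = openai.Completion.create(model="text-davinci-003", prompt="hello", stream=True)
    for chunk in resp:  # iterating the response drives traced_streamed_response()
        print(chunk["choices"][0]["text"], end="")
    # once the stream is exhausted, the wrapper records the estimated
    # completion and total token counts on the OpenAI span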


class _CompletionHook(_BaseCompletionHook):
_request_tag_attrs = [
@@ -372,16 +462,27 @@ class _CompletionHook(_BaseCompletionHook):
"user",
]

def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
sample_pc_span = integration.is_pc_sampled_span(span)

if sample_pc_span:
prompt = kwargs.get("prompt", "")
if isinstance(prompt, list):
if isinstance(prompt, str):
span.set_tag_str("request.prompt", integration.trunc(prompt))
elif prompt:
for idx, p in enumerate(prompt):
span.set_tag_str("request.prompt.%d" % idx, integration.trunc(p))
elif prompt:
span.set_tag_str("request.prompt", integration.trunc(prompt))

if "stream" in kwargs and kwargs["stream"]:
prompt = kwargs.get("prompt", "")
num_prompt_tokens = 0
if isinstance(prompt, str):
num_prompt_tokens += _est_tokens(prompt)
else:
for p in prompt:
num_prompt_tokens += _est_tokens(p)
span.set_metric("response.usage.prompt_tokens", num_prompt_tokens)
integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens)

self._record_request(span, kwargs)

@@ -412,6 +513,7 @@ def handle_request(self, integration, span, args, kwargs):
"choices": resp["choices"] if resp and "choices" in resp else [],
},
)
return self._handle_response(pin, span, integration, resp)


class _ChatCompletionHook(_BaseCompletionHook):
@@ -428,22 +530,22 @@ class _ChatCompletionHook(_BaseCompletionHook):
"user",
]

def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
sample_pc_span = integration.is_pc_sampled_span(span)
messages = kwargs.get("messages")
if sample_pc_span and messages:

def set_message_tag(m):
for idx, m in enumerate(messages):
content = integration.trunc(m.get("content", ""))
role = integration.trunc(m.get("role", ""))
span.set_tag_str("request.messages.%d.content" % idx, content)
span.set_tag_str("request.messages.%d.role" % idx, role)

if isinstance(messages, list):
for idx, message in enumerate(messages):
set_message_tag(message)
else:
set_message_tag(messages)
if "stream" in kwargs and kwargs["stream"]:
num_message_tokens = 0
for m in messages:
num_message_tokens += _est_tokens(m.get("content", ""))
span.set_metric("response.usage.prompt_tokens", num_message_tokens)
integration.metric(span, "dist", "tokens.prompt", num_message_tokens)

self._record_request(span, kwargs)

@@ -481,10 +583,11 @@ def set_message_tag(m):
"completion": completions,
},
)
return self._handle_response(pin, span, integration, resp)


class _EmbeddingHook(_EndpointHook):
def handle_request(self, integration, span, args, kwargs):
def handle_request(self, pin, integration, span, args, kwargs):
for kw_attr in ["model", "input", "user"]:
if kw_attr in kwargs:
if kw_attr == "input" and isinstance(kwargs["input"], list):
@@ -509,10 +612,15 @@ def _patched_convert(integration):
def patched_convert(openai, pin, func, instance, args, kwargs):
"""Patch convert captures header information in the openai response"""
for val in args:
        # FIXME: these header tags are re-reported for every streamed chunk.
        # The organization.name tag is checked below as a signal to avoid
        # repeating this work for each chunk. TODO: find a better signal.
span = pin.tracer.current_span()
if not span:
return func(*args, **kwargs)
if span.get_tag("organization.name") is not None:
continue
if isinstance(val, openai.openai_response.OpenAIResponse):
span = pin.tracer.current_span()
if not span:
return func(*args, **kwargs)
val = val._headers
if val.get("openai-organization"):
org_name = val.get("openai-organization")
