Skip to content

Commit

Permalink
Use a generator to share sync/async code
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle-Verhoog committed Apr 13, 2023
1 parent c18e64f commit 95cd8a7
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 97 deletions.
20 changes: 3 additions & 17 deletions ddtrace/contrib/django/utils.py
Expand Up @@ -103,20 +103,6 @@ def get_django_2_route(request, resolver_match):
return None


def set_tag_array(span, prefix, value):
    """Helper to set a span tag as a single value or an array"""
    if not value:
        return

    if len(value) == 1:
        # Single element: store it directly under the bare prefix.
        first = value[0]
        if first:
            span.set_tag_str(prefix, first)
        return

    # Multiple elements: store each non-empty one under an indexed key.
    for index, item in enumerate(value):
        if item:
            span.set_tag_str(prefix + "." + str(index), item)


def get_request_uri(request):
"""
Helper to rebuild the original request url
Expand Down Expand Up @@ -213,11 +199,11 @@ def _set_resolver_tags(pin, span, request):
resource = " ".join((request.method, handler))

span.set_tag_str("django.view", resolver_match.view_name)
set_tag_array(span, "django.namespace", resolver_match.namespaces)
trace_utils.set_tag_array(span, "django.namespace", resolver_match.namespaces)

# Django >= 2.0.0
if hasattr(resolver_match, "app_names"):
set_tag_array(span, "django.app", resolver_match.app_names)
trace_utils.set_tag_array(span, "django.app", resolver_match.app_names)

except Resolver404:
# Normalize all 404 requests into a single resource name
Expand Down Expand Up @@ -364,7 +350,7 @@ def _after_request_tags(pin, span, request, response):
else:
template_names = None

set_tag_array(span, "django.response.template", template_names)
trace_utils.set_tag_array(span, "django.response.template", template_names)

url = get_request_uri(request)

Expand Down
3 changes: 2 additions & 1 deletion ddtrace/contrib/openai/_log.py
Expand Up @@ -43,10 +43,11 @@ def log(level, msg, tags, attrs):
timestamp = datetime.datetime.now().isoformat()

log = {
"message": "%s %s %s" % (timestamp, level, msg),
"message": "%s %s" % (timestamp, msg),
"hostname": os.getenv("DD_HOSTNAME", get_hostname()),
"ddsource": "python",
"service": "openai",
"status": level,
}
if config.env:
tags.append("env:%s" % config.env)
Expand Down
5 changes: 2 additions & 3 deletions ddtrace/contrib/openai/_metrics.py
Expand Up @@ -12,9 +12,8 @@
def stats_client():
global _statsd
if _statsd is None and config.openai.metrics_enabled:
# FIXME: this currently does not consider if the tracer
# is configured to use a different hostname.
# eg. tracer.configure(host="new-hostname")
# FIXME: this currently does not consider if the tracer is configured to
# use a different hostname. eg. tracer.configure(host="new-hostname")

# FIXME: the dogstatsd client doesn't support multi-threaded usage
_statsd = get_dogstatsd_client(
Expand Down
153 changes: 88 additions & 65 deletions ddtrace/contrib/openai/patch.py
@@ -1,7 +1,9 @@
import os
import sys

from ddtrace import config
from ddtrace.contrib.trace_utils import set_flattened_tags
from ddtrace.contrib.trace_utils import set_tag_array
from ddtrace.internal.constants import COMPONENT
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils import get_argument_value
Expand All @@ -13,9 +15,7 @@
from ...pin import Pin
from ..trace_utils import wrap
from ._metrics import stats_client
from ._openai import CHAT_COMPLETIONS
from ._openai import COMPLETIONS
from ._openai import EMBEDDINGS
from ._openai import process_request
from ._openai import process_response
from ._openai import supported
Expand All @@ -29,6 +29,7 @@
"logs_enabled": asbool(os.getenv("DD_OPENAI_LOGS_ENABLED", False)),
"metrics_enabled": asbool(os.getenv("DD_OPENAI_METRICS_ENABLED", True)),
"prompt_completion_sample_rate": float(os.getenv("DD_OPENAI_PROMPT_COMPLETION_SAMPLE_RATE", 1.0)),
# TODO: truncate threshold on prompts/completions
"_default_service": "openai",
},
)
Expand All @@ -37,11 +38,6 @@
log = get_logger(__file__)


REQUEST_TAG_PREFIX = "request"
RESPONSE_TAG_PREFIX = "response"
ERROR_TAG_PREFIX = "error"


def patch():
# Avoid importing openai at the module level, eventually will be an import hook
import openai
Expand All @@ -60,35 +56,32 @@ def patch():
# wrap(openai, "api_resources.abstract.engine_api_resource.EngineAPIResource.create", patched_create(openai))
# wrap(openai, "api_resources.abstract.engine_api_resource.EngineAPIResource.acreate", patched_async_create(openai))

if supported(CHAT_COMPLETIONS):
wrap(openai, "api_resources.chat_completion.ChatCompletion.create", patched_endpoint(openai))
wrap(openai, "api_resources.chat_completion.ChatCompletion.acreate", patched_async_endpoint(openai))
# if supported(CHAT_COMPLETIONS):
# wrap(openai, "api_resources.chat_completion.ChatCompletion.create", patched_endpoint(openai))
# wrap(openai, "api_resources.chat_completion.ChatCompletion.acreate", patched_async_endpoint(openai))

if supported(COMPLETIONS):
wrap(openai, "api_resources.completion.Completion.create", patched_endpoint(openai))
wrap(openai, "api_resources.completion.Completion.acreate", patched_async_endpoint(openai))
wrap(openai, "api_resources.completion.Completion.create", patched_completion_create(openai))
wrap(openai, "api_resources.completion.Completion.acreate", patched_completion_acreate(openai))

if supported(EMBEDDINGS):
wrap(openai, "api_resources.embedding.Embedding.create", patched_endpoint(openai))
wrap(openai, "api_resources.embedding.Embedding.acreate", patched_async_endpoint(openai))
# if supported(EMBEDDINGS):
# wrap(openai, "api_resources.embedding.Embedding.create", patched_endpoint(openai))
# wrap(openai, "api_resources.embedding.Embedding.acreate", patched_async_endpoint(openai))

Pin().onto(openai)
setattr(openai, "__datadog_patch", True)

if config.openai.logs_enabled:
ddsite = os.getenv("DD_SITE", "datadoghq.com")
ddapikey = os.getenv("DD_API_KEY")
# TODO: replace with proper check
assert ddapikey

if not ddapikey:
raise ValueError("DD_API_KEY is required for sending logs from the OpenAI integration")

ddlogs.start(
site=ddsite,
api_key=ddapikey,
)
# TODO: these logs don't show up when DD_TRACE_DEBUG=1 set... same thing for all contribs?
# FIXME: these logs don't show up when DD_TRACE_DEBUG=1 set... same thing for all contribs?
log.debug("started log writer")


Expand All @@ -99,8 +92,15 @@ def unpatch():
pass


def _set_completion_request_tags(span, kwargs):
"""Set span tags for a completion or chat completion request."""
def _completion_create(openai, pin, instance, args, kwargs):
span = pin.tracer.trace(
"openai.request", resource="completions", service=trace_utils.ext_service(pin, config.openai)
)
init_openai_span(span, openai)

model = get_argument_value(args, kwargs, -1, "model")
prompt = get_argument_value(args, kwargs, -1, "prompt")
span.set_tag_str("model", model)
kw_attrs = [
"model",
"suffix",
Expand All @@ -123,65 +123,90 @@ def _set_completion_request_tags(span, kwargs):
if kw_attr in kwargs:
span.set_tag("request.%s" % kw_attr, kwargs[kw_attr])

resp, error = yield span

def _set_completion_response_tags(span, kwargs):
kw_attrs = [
"id",
"object",
"created",
"choices",
"usage",
metric_tags = [
"model:%s" % kwargs.get("model"),
"endpoint:%s" % instance.OBJECT_NAME,
"error:%d" % (1 if error else 0),
]
for kw_attr in kw_attrs:
if kw_attr in kwargs:
span.set_tag("request.%s" % kw_attr, kwargs[kw_attr])
completions = ""

if error is not None:
span.set_exc_info(*sys.exc_info())
if isinstance(error, openai.error.OpenAIError):
# TODO?: handle specific OpenAIError types
pass
stats_client().increment("error", 1, tags=metric_tags + ["error_type:%s" % error.__class__.__name__])
if resp:
if "choices" in resp:
choices = resp["choices"]
if len(choices) > 1:
completions = "\n".join(["%s: %s" % (c.get("index"), c.get("text")) for c in choices])
else:
completions = choices[0].get("text")

span.set_tag("response.choices.num", len(choices))
for choice in choices:
idx = choice["index"]
span.set_tag_str("response.choices.%d.finish_reason" % idx, choice.get("finish_reason"))
span.set_tag("response.choices.%d.logprobs" % idx, choice.get("logprobs"))
for kw_attr in ["id", "object", "usage"]:
if kw_attr in kwargs:
span.set_tag("request.%s" % kw_attr, kwargs[kw_attr])

usage_metrics(resp.get("usage"), metric_tags)

@trace_utils.with_traced_module
def patched_endpoint(openai, pin, func, instance, args, kwargs):
span = pin.tracer.trace(
"openai.create", resource="completions", service=trace_utils.ext_service(pin, config.openai)
# TODO: determine best format for multiple choices/completions
ddlogs.log(
"info" if error is None else "error",
"sampled completion",
tags=["model:%s" % kwargs.get("model")],
attrs={
"prompt": prompt,
"completion": completions, # TODO: should be completions (plural)?
},
)
model = get_argument_value(args, kwargs, -1, "model")
prompt = get_argument_value(args, kwargs, -1, "prompt")
span.set_tag_str("model", model)
_set_completion_request_tags(span, kwargs)
span.finish()
stats_client().distribution("request.duration", span.duration_ns, tags=metric_tags)


@trace_utils.with_traced_module
def patched_completion_create(openai, pin, func, instance, args, kwargs):
g = _completion_create(openai, pin, instance, args, kwargs)
g.send(None)
resp, resp_err = None, None
try:
resp = func(*args, **kwargs)
return resp
except openai.error.OpenAIError as err:
except Exception as err:
resp_err = err
raise
finally:
completions = "\n".join(["%s: %s" % (c.get("index"), c.get("text")) for c in resp.get("choices")])
ddlogs.log(
"INFO",
"sampled completion",
tags=["model:%s" % kwargs.get("model")],
attrs={
"prompt": prompt,
"completion": completions,
},
)
metric_tags = ["model:%s" % kwargs.get("model"), "endpoint:%s" % instance.OBJECT_NAME]
usage_metrics(resp.get("usage"), metric_tags)
span.finish()
stats_client().distribution("request.duration", span.duration_ns, tags=metric_tags)
try:
g.send((resp, resp_err))
except StopIteration:
# expected
pass


@trace_utils_async.with_traced_module
async def patched_async_endpoint(openai, pin, func, instance, args, kwargs):
# resource name is set to the model being used -- if that name is not found, use the engine name
span = start_endpoint_span(openai, pin, instance, args, kwargs)
async def patched_completion_acreate(openai, pin, func, instance, args, kwargs):
g = _completion_create(openai, pin, instance, args, kwargs)
g.send(None)
resp, resp_err = None, None
try:
resp = await func(*args, **kwargs)
return resp
except openai.error.OpenAIError as err:
except Exception as err:
resp_err = err
raise
finally:
finish_endpoint_span(span, resp, resp_err, openai, instance, kwargs)
try:
g.send((resp, resp_err))
except StopIteration:
# expected
pass


@trace_utils.with_traced_module
Expand All @@ -190,7 +215,7 @@ def patched_create(openai, pin, func, instance, args, kwargs):
"openai.request", resource=instance.OBJECT_NAME, service=trace_utils.ext_service(pin, config.openai)
)
try:
init_openai_span(span, openai, kwargs.get("model"))
init_openai_span(span, openai)
resp = func(*args, **kwargs)
return resp
except openai.error.OpenAIError as err:
Expand All @@ -206,7 +231,7 @@ async def patched_async_create(openai, pin, func, instance, args, kwargs):
"openai.request", resource=instance.OBJECT_NAME, service=trace_utils.ext_service(pin, config.openai)
)
try:
init_openai_span(span, openai, kwargs.get("model"))
init_openai_span(span, openai)
resp = await func(*args, **kwargs)
return resp
except openai.error.OpenAIError as err:
Expand All @@ -217,9 +242,7 @@ async def patched_async_create(openai, pin, func, instance, args, kwargs):


# set basic openai data for all openai spans
def init_openai_span(span, openai, model):
if model:
span.set_tag_str("model", model)
def init_openai_span(span, openai):
span.set_tag_str(COMPONENT, config.openai.integration_name)
if hasattr(openai, "api_base") and openai.api_base:
span.set_tag_str("api_base", openai.api_base)
Expand All @@ -231,7 +254,7 @@ def start_endpoint_span(openai, pin, instance, args, kwargs):
span = pin.tracer.trace(
"openai.create", resource=instance.OBJECT_NAME, service=trace_utils.ext_service(pin, config.openai)
)
init_openai_span(span, openai, kwargs.get("model"))
init_openai_span(span, openai)
set_flattened_tags(
span,
append_tag_prefixes([REQUEST_TAG_PREFIX], process_request(openai, instance.OBJECT_NAME, args, kwargs)),
Expand Down
15 changes: 15 additions & 0 deletions ddtrace/contrib/trace_utils.py
Expand Up @@ -605,6 +605,21 @@ def set_flattened_tags(
span.set_tag(tag, processor(v) if processor is not None else v)


def set_tag_array(span, prefix, value):
    # type: (Span, str, List[str]) -> None
    """Set a span tag from a list of string values.

    A single-element list is stored directly under ``prefix``; a longer list
    is stored under indexed keys (``prefix.0``, ``prefix.1``, ...). Empty or
    ``None`` values, and falsy elements, are skipped entirely.
    """
    if not value:
        return

    if len(value) == 1:
        if value[0]:
            span.set_tag_str(prefix, value[0])
    else:
        # enumerate starts at 0 by default; no need to pass start=0 explicitly.
        for i, v in enumerate(value):
            if v:
                span.set_tag_str("".join((prefix, ".", str(i))), v)


def set_user(tracer, user_id, name=None, email=None, scope=None, role=None, session_id=None, propagate=False):
# type: (Tracer, str, Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], bool) -> None
"""Set user tags.
Expand Down

0 comments on commit 95cd8a7

Please sign in to comment.