-
Notifications
You must be signed in to change notification settings - Fork 24
/
tracing.py
251 lines (216 loc) · 8.6 KB
/
tracing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import json
import logging
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast
from fastapi import Request, Response
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from pccommon.config import get_apis_config
from pccommon.constants import (
HTTP_METHOD,
HTTP_PATH,
HTTP_STATUS_CODE,
HTTP_URL,
X_AZURE_REF,
X_REQUEST_ENTITY,
)
from pccommon.logging import request_to_path
from pccommon.utils import get_request_ip
# API configuration, resolved once when the module is imported.
_config = get_apis_config()
logger = logging.getLogger(__name__)

# An Azure exporter is only constructed when an instrumentation key is
# configured; otherwise there is no exporter and tracing stays disabled.
if _config.app_insights_instrumentation_key:
    exporter: Optional[AzureExporter] = AzureExporter(
        connection_string=(
            f"InstrumentationKey={_config.app_insights_instrumentation_key}"
        )
    )
else:
    exporter = None

# Tracing is enabled exactly when an exporter exists.
is_trace_enabled = exporter is not None
async def trace_request(
    service_name: str,
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Run the next middleware inside a "main" span tagged with request dimensions.

    Requests rejected by `_should_trace_request` are forwarded to `call_next`
    untouched. For traced requests, the span is stored on
    `request.state.parent_span`, request attributes are recorded before the
    downstream call, and the response status code is recorded after it.

    Args:
        service_name: Value recorded under the "service" attribute; also
            forwarded to `_collection_item_from_request`.
        request: The incoming request.
        call_next: Callable that invokes the rest of the middleware chain.

    Returns:
        The response produced by `call_next`.
    """
    request_path = request_to_path(request).strip("/")

    # Guard clause: untraced requests skip all span bookkeeping.
    if not _should_trace_request(request):
        return await call_next(request)

    tracer = Tracer(
        exporter=exporter,
        sampler=ProbabilitySampler(1.0),
    )
    # Short local alias; every call still targets the tracer's current span.
    add_dimension = tracer.add_attribute_to_current_span

    with tracer.span("main") as root_span:
        collection_id, item_id = await _collection_item_from_request(
            service_name, request
        )
        root_span.span_kind = SpanKind.SERVER

        # Exposing the main span on request state lets downstream handlers
        # open child spans around specific slow operations.
        request.state.parent_span = root_span

        # Request dimensions, recorded before the downstream call.
        add_dimension(
            attribute_key="ref_id",
            attribute_value=request.headers.get(X_AZURE_REF),
        )
        add_dimension(
            attribute_key="request_entity",
            attribute_value=request.headers.get(X_REQUEST_ENTITY),
        )
        add_dimension(
            attribute_key="request_ip",
            attribute_value=get_request_ip(request),
        )
        add_dimension(attribute_key=HTTP_METHOD, attribute_value=str(request.method))
        add_dimension(attribute_key=HTTP_URL, attribute_value=str(request.url))
        add_dimension(attribute_key=HTTP_PATH, attribute_value=request_path)
        add_dimension(attribute_key="service", attribute_value=service_name)
        add_dimension(attribute_key="in-server", attribute_value="true")

        # NOTE(review): the item check is nested under the collection check;
        # the scraped source lost its indentation, so confirm this nesting.
        if collection_id is not None:
            add_dimension(attribute_key="collection", attribute_value=collection_id)
            if item_id is not None:
                add_dimension(attribute_key="item", attribute_value=item_id)

        # Hand off to the rest of the middleware chain.
        response = await call_next(request)

        # Response dimensions, recorded once the downstream call has returned.
        add_dimension(
            attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code
        )
    return response
# Extracts the collection id and, when present, the item id from a string of
# the form ".../collections/<collection_id>/items/<item_id>".
# - Both named groups are optional, so a bare ".../collections" still matches
#   with `collection_id` and `item_id` set to None.
# - `collection_id` accepts only ASCII letters, digits, "-" and "%".
# - `item_id` is a greedy ".*" and captures everything following "/items/".
collection_id_re = re.compile(
    r".*/collections/?(?P<collection_id>[a-zA-Z0-9\-\%]+)?(/items/(?P<item_id>.*))?.*" # noqa
)
async def _collection_item_from_request(
    service_name: str,
    request: Request,
) -> Tuple[Optional[str], Optional[str]]:
    """Attempt to get collection and item ids from the request path or querystring.

    Sources are tried in this order:

    1. A ``/collections/<id>[/items/<id>]`` pattern in the URL path.
    2. For paths whose last segment is ``search`` or ``register``, the
       search-specific parsing in `_parse_collection_from_search`.
    3. The ``collection`` and ``item`` (or ``items``) query parameters.

    This is best-effort: any exception is logged and reported as
    ``(None, None)`` so that id extraction can never fail the request.

    Args:
        service_name: Unused here; kept so the signature stays stable.
        request: The incoming request.

    Returns:
        A ``(collection_id, item_id)`` tuple; either element may be None.
    """
    url = request.url
    path = url.path.strip("/")
    try:
        # Match against the path only. Matching the full URL let the greedy
        # item_id group absorb the query string ("abc?limit=1").
        collection_id_match = collection_id_re.match(url.path)
        if collection_id_match:
            collection_id = cast(
                Optional[str], collection_id_match.group("collection_id")
            )
            item_id = cast(Optional[str], collection_id_match.group("item_id"))
            return (collection_id, item_id)

        # Compare the final path segment rather than testing for a "/search"
        # suffix: `path` has its slashes stripped, so a root-mounted "/search"
        # is just "search" and would never end with "/search".
        if path.rsplit("/", 1)[-1] in ("search", "register"):
            return await _parse_collection_from_search(request)

        collection_id = request.query_params.get("collection")
        # Some endpoints, like preview/, take an `items` parameter, but
        # conventionally it is a single item id
        item_id = request.query_params.get("item") or request.query_params.get(
            "items"
        )
        return (collection_id, item_id)
    except Exception as e:
        # Deliberately broad: id extraction must never break the request.
        logger.exception(e)
        return (None, None)
def _should_trace_request(request: Request) -> bool:
    """
    Decide whether a request gets a trace.

    A request is traced only when all of the following hold:
    - Tracing is enabled (an exporter was configured)
    - It is not a HEAD request
    - Its path does not end with the `_mgmt/ping` health check endpoint
    """
    if not is_trace_enabled:
        return False

    if request.method.lower() == "head":
        return False

    stripped_path = request.url.path.strip("/")
    return not stripped_path.endswith("_mgmt/ping")
async def _parse_collection_from_search(
    request: Request,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse the collection and item ids from a search request.

    The search endpoint is a bit of a special case. If it's a GET, the collection
    and item ids are in the querystring (`collections` and `ids`). If it's a POST,
    they may be in a traditional query body (a `collections` key) or in a CQL-JSON
    / CQL2-JSON `filter` object.

    Bodies that cannot be decoded as JSON, bodies that are not JSON objects, and
    filters that are not JSON objects (for example a text-encoded filter string)
    all yield `(None, None)` instead of raising.

    Returns:
        A `(collection_id, item_id)` tuple; either element may be None.
    """
    method = request.method.lower()

    if method == "get":
        collection_id = request.query_params.get("collections")
        item_id = request.query_params.get("ids")
        return (collection_id, item_id)

    if method == "post":
        try:
            body = await request.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError covers bodies that are not valid UTF-8/16/32,
            # which fail before JSON parsing even starts.
            logger.warning(
                "Unable to parse search body as JSON. Ignoring collection parameter."
            )
            return (None, None)

        # Valid JSON may be any value (list, number, null, ...); only an
        # object can carry the keys we are looking for.
        if isinstance(body, dict):
            if "collections" in body:
                return _parse_queryjson(body)

            search_filter = body.get("filter")
            # A filter may also be sent as a text string; only the JSON
            # object encodings can be walked for property references.
            if isinstance(search_filter, dict):
                return _parse_cqljson(search_filter)

    return (None, None)
def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a CQL-JSON filter.
The CQL-JSON filter is a bit of a special case. It's a JSON object in either
CQL or CQL2 syntax. Parse the object and look for the collection and item
ids. If multiple collections or items are found, format them to a CSV.
"""
collections = _iter_cql(cql, property_name="collection")
ids = _iter_cql(cql, property_name="id")
if isinstance(collections, list):
collections = ",".join(collections)
if isinstance(ids, list):
ids = ",".join(ids)
return (collections, ids)
def _parse_queryjson(query: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection and item ids from a traditional STAC API Item Search body.
The query is a JSON object with relevant keys, "collections" and "ids".
"""
collection_ids = query.get("collections")
item_ids = query.get("ids")
# Collection and ids are List[str] per the spec,
# but the client may allow just a single item
if isinstance(collection_ids, list):
collection_ids = ",".join(collection_ids)
if isinstance(item_ids, list):
item_ids = ",".join(item_ids)
return (collection_ids, item_ids)
def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
"""
Recurse through a CQL or CQL2 filter body, returning the value of the
provided property name, if found. Typical usage will be to provide
`collection` and `id`.
"""
for _, v in cql.items():
if isinstance(v, dict):
result = _iter_cql(v, property_name)
if result is not None:
return result
elif isinstance(v, list):
for item in v:
if isinstance(item, dict):
if "property" in item:
if item["property"] == property_name:
return v[1]
else:
result = _iter_cql(item, property_name)
if result is not None:
return result
# No collection was found
return None