Commit

Merge pull request #4041 from freelawproject/4032-fix-v4-search-api-date-fields-sorting
mlissner committed May 14, 2024
2 parents 442cc86 + 4860193 commit 266620b
Showing 6 changed files with 361 additions and 72 deletions.
43 changes: 35 additions & 8 deletions cl/api/pagination.py
@@ -1,3 +1,4 @@
import datetime
from base64 import b64decode, b64encode
from collections import defaultdict
from urllib.parse import parse_qs, urlencode
@@ -92,17 +93,31 @@ class ESCursorPagination(BasePagination):
search_type = None
cursor_query_param = "cursor"
invalid_cursor_message = "Invalid cursor"
request_date = None

def initialize_context_from_request(
self, request, search_type
) -> datetime.date:
self.base_url = request.build_absolute_uri()
self.request = request
self.search_type = search_type
self.cursor = self.decode_cursor(request)

# Set the request date from the cursor or provide an initial one if
# this is the first page request.
self.request_date = (
self.cursor.request_date
if self.cursor
else datetime.datetime.now().date()
)
return self.request_date

def paginate_queryset(
self, es_list_instance: CursorESList, request: Request, view=None
) -> list[defaultdict]:
"""Paginate the Elasticsearch query and retrieve the results."""

self.base_url = request.build_absolute_uri()
self.request = request
self.es_list_instance = es_list_instance
self.search_type = self.es_list_instance.clean_data["type"]
self.cursor = self.decode_cursor(request)
self.es_list_instance.set_pagination(
self.cursor, settings.SEARCH_API_PAGE_SIZE
)
@@ -149,6 +164,7 @@ def get_next_link(self) -> str | None:
search_after=search_after_sort_key,
reverse=False,
search_type=self.search_type,
request_date=self.request_date,
)
return self.encode_cursor(cursor)

@@ -166,6 +182,7 @@ def get_previous_link(self) -> str | None:
search_after=reverse_search_after_sort_key,
reverse=True,
search_type=self.search_type,
request_date=self.request_date,
)
return self.encode_cursor(cursor)

@@ -179,9 +196,9 @@ def decode_cursor(self, request: Request) -> ESCursor | None:
querystring = b64decode(encoded.encode("ascii")).decode("ascii")
tokens = parse_qs(querystring, keep_blank_values=True)
search_after = tokens.get("s", None)
reverse = tokens.get("r", ["0"])[0]
reverse = bool(int(reverse))
reverse = bool(int(tokens.get("r", ["0"])[0]))
search_type = tokens.get("t", [None])[0]
request_date = tokens.get("d", [None])[0]
except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message)

@@ -190,9 +207,17 @@ def decode_cursor(self, request: Request) -> ESCursor | None:
# in the cursor doesn't match, raise an invalid cursor error to
# avoid pagination inconsistencies.
raise NotFound(self.invalid_cursor_message)
return ESCursor(
search_after=search_after, reverse=reverse, search_type=search_type

request_date = (
datetime.date.fromisoformat(request_date) if request_date else None
)
self.cursor = ESCursor(
search_after=search_after,
reverse=reverse,
search_type=search_type,
request_date=request_date,
)
return self.cursor

def encode_cursor(self, cursor: ESCursor) -> str:
"""Given a ESCursor instance, return an url with encoded cursor."""
@@ -203,6 +228,8 @@ def encode_cursor(self, cursor: ESCursor) -> str:
tokens["r"] = "1"
if cursor.search_type:
tokens["t"] = self.search_type
if cursor.request_date:
tokens["d"] = cursor.request_date

querystring = urlencode(tokens, doseq=True)
encoded = b64encode(querystring.encode("ascii")).decode("ascii")
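
Note on the hunk above: the cursor is just a base64-encoded querystring, and the new "d" token carries the request date so every page of a result set is evaluated against the same date. A rough, self-contained sketch of the round trip (not the project's actual class; only the token names "s", "r", "t", and "d" come from the diff above):

import datetime
from base64 import b64decode, b64encode
from urllib.parse import parse_qs, urlencode

def encode_cursor(search_after, reverse, search_type, request_date):
    tokens = {"s": search_after, "t": search_type, "d": request_date.isoformat()}
    if reverse:
        tokens["r"] = "1"
    querystring = urlencode(tokens, doseq=True)
    return b64encode(querystring.encode("ascii")).decode("ascii")

def decode_cursor(encoded):
    querystring = b64decode(encoded.encode("ascii")).decode("ascii")
    tokens = parse_qs(querystring, keep_blank_values=True)
    request_date = tokens.get("d", [None])[0]
    return {
        "search_after": tokens.get("s", None),
        "reverse": bool(int(tokens.get("r", ["0"])[0])),
        "search_type": tokens.get("t", [None])[0],
        "request_date": (
            datetime.date.fromisoformat(request_date) if request_date else None
        ),
    }

# The date stamped on the first page request survives the round trip, so
# later pages reuse it instead of calling datetime.datetime.now() again.
first_page_date = datetime.date(2024, 5, 14)
cursor = encode_cursor(["2024-05-14T00:00:00Z", "123"], False, "r", first_page_date)
assert decode_cursor(cursor)["request_date"] == first_page_date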
125 changes: 83 additions & 42 deletions cl/lib/elasticsearch_utils.py
@@ -1,11 +1,11 @@
import datetime
import logging
import operator
import re
import time
import traceback
from copy import deepcopy
from dataclasses import fields
from datetime import date, datetime, timedelta
from functools import reduce, wraps
from typing import Any, Callable, Dict, List, Literal

@@ -127,8 +127,8 @@ def build_numeric_range_query(

def build_daterange_query(
field: str,
before: date,
after: date,
before: datetime.date,
after: datetime.date,
relation: Literal["INTERSECTS", "CONTAINS", "WITHIN", None] = None,
) -> list[Range]:
"""Given field name and date range limits returns ElasticSearch range query or None
@@ -728,7 +728,10 @@ def build_highlights_dict(


def build_custom_function_score_for_date(
query: QueryString | str, order_by: tuple[str, str], default_score: int
query: QueryString | str,
order_by: tuple[str, str],
default_score: int,
default_current_date: datetime.date | None = None,
) -> QueryString:
"""Build a custom function score query for sorting based on a date field.
@@ -751,16 +754,31 @@ def build_custom_function_score_for_date(
results based on a child document field.
:param default_score: The default score to return when the document lacks
the sort field.
:param default_current_date: The default current date to use for computing
a stable date score across pagination in the V4 Search API.
:return: The modified QueryString object with applied function score.
"""

default_current_time = None
if default_current_date:
midnight_current_date = datetime.datetime.combine(
default_current_date, datetime.time()
)
default_current_time = int(midnight_current_date.timestamp() * 1000)

sort_field, order = order_by
query = Q(
"function_score",
query=query,
script_score={
"script": {
"source": f"""
long current_time;
if (params.default_current_time != null) {{
current_time = params.default_current_time; // Use 'default_current_time' if provided
}} else {{
current_time = new Date().getTime();
}}
// Check if the document has a value for the 'sort_field'
if (doc['{sort_field}'].size() == 0) {{
return {default_score}; // If not, return 'default_score' as the score
@@ -769,7 +787,7 @@ def build_custom_function_score_for_date(
// (February 22, 1732)
long washington_bd_offset = 7506086400000L;
// Get the current time in milliseconds, include the washington_bd_offset to work with positive epoch times.
long current_time = new Date().getTime() + washington_bd_offset;
current_time = current_time + washington_bd_offset;
// Convert the 'sort_field' value to epoch milliseconds, adjusting by the same offset.
long date_filed_time = doc['{sort_field}'].value.toInstant().toEpochMilli() + washington_bd_offset;
@@ -788,7 +806,11 @@ def build_custom_function_score_for_date(
}}
""",
# Parameters passed to the script
"params": {"order": order, "default_score": default_score},
"params": {
"order": order,
"default_score": default_score,
"default_current_time": default_current_time,
},
},
},
# Replace the original score with the one computed by the script
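
The default_current_time parameter freezes "now" at the date of the first page request: midnight of the request date is converted to epoch milliseconds and passed to the Painless script, which otherwise falls back to new Date().getTime(). A minimal Python sketch of just that conversion, mirroring the lines above:

import datetime

def default_current_time_millis(request_date: datetime.date | None) -> int | None:
    """Epoch milliseconds at midnight of the request date, or None so the
    script falls back to new Date().getTime()."""
    if request_date is None:
        return None
    midnight = datetime.datetime.combine(request_date, datetime.time())
    return int(midnight.timestamp() * 1000)

# Requests made hours or days apart, but carrying the same cursor date, hand
# the script an identical reference time, so the date-based scores (and with
# them the search_after sort order) stay stable across pages.
page_1 = default_current_time_millis(datetime.date(2024, 5, 14))
page_2 = default_current_time_millis(datetime.date(2024, 5, 14))
assert page_1 == page_2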
@@ -805,7 +827,7 @@ def build_has_child_query(
highlighting_fields: dict[str, int] | None = None,
order_by: tuple[str, str] | None = None,
child_highlighting: bool = True,
api_version: Literal["v3", "v4"] | None = None,
default_current_date: datetime.date | None = None,
) -> QueryString:
"""Build a 'has_child' query.
@@ -816,17 +838,23 @@ def build_has_child_query(
:param order_by: If provided the field to use to compute score for sorting
results based on a child document field.
:param child_highlighting: Whether highlighting should be enabled in child docs.
:param api_version: Optional, the request API version.
:param default_current_date: The default current date to use for computing
a stable date score across pagination in the V4 Search API.
:return: The 'has_child' query.
"""

if order_by and all(order_by) and child_type == "recap_document":
if api_version == "v4":
query = nullify_query_score(query)
else:
query = build_custom_function_score_for_date(
query, order_by, default_score=1
)
if (
order_by
and all(order_by)
and child_type == "recap_document"
and order_by[0] == "entry_date_filed"
):
query = build_custom_function_score_for_date(
query,
order_by,
default_score=1,
default_current_date=default_current_date,
)

highlight_options, fields_to_exclude = build_highlights_dict(
highlighting_fields, SEARCH_HL_TAG, child_highlighting
@@ -881,17 +909,20 @@ def get_search_query(
query_hits_limit,
SEARCH_RECAP_CHILD_HL_FIELDS,
get_child_sorting_key(cd, api_version),
api_version=api_version,
default_current_date=cd.get("request_date"),
)
match_all_parent_query = apply_custom_score_to_parent_query(
cd, Q("match", docket_child="docket"), api_version
match_all_parent_query = Q("match", docket_child="docket")
match_all_parent_query = nullify_query_score(
match_all_parent_query
)
return search_query.query(
Q(
"bool",
should=[match_all_child_query, match_all_parent_query],
)
final_match_all_query = Q(
"bool",
should=[match_all_child_query, match_all_parent_query],
)
final_match_all_query = apply_custom_score_to_parent_query(
cd, final_match_all_query, api_version
)
return search_query.query(final_match_all_query)
case SEARCH_TYPES.OPINION:
# Only return Opinion clusters.
q_should = [
@@ -1469,7 +1500,7 @@ def convert_str_date_fields_to_date_objects(
top_hits = result.grouped_by_opinion_cluster_id.hits.hits
for hit in top_hits:
date_str = hit["_source"][date_field_name]
date_obj = date.fromisoformat(date_str)
date_obj = datetime.date.fromisoformat(date_str)
hit["_source"][date_field_name] = date_obj


@@ -1539,7 +1570,9 @@ def fill_position_mapping(

if callable(field_value):
field_value = field_value()
elif isinstance(field_value, (datetime, date)):
elif isinstance(
field_value, (datetime.datetime, datetime.date)
):
field_value = midnight_pt(field_value)

mapping_dict[person_id].append(field_value)
@@ -2129,29 +2162,33 @@ def apply_custom_score_to_parent_query(
match cd["type"]:
case SEARCH_TYPES.RECAP | SEARCH_TYPES.DOCKETS if valid_child_order_by:
sort_field, order = child_order_by
if sort_field == "entry_date_filed":
# It applies a function score to the parent query to nullify
# the parent score (sets it to 0) to prioritize child documents
# sorting criteria. This will ensure that dockets without
# documents come last on results.
query = nullify_query_score(query)
elif sort_field == "dateFiled" and api_version:
if sort_field == "dateFiled" and api_version == "v4":
# Applies a custom function score to sort Dockets based on
# their dateFiled field. This serves as a workaround to enable
# the use of the search_after cursor for pagination on
# documents with a None dateFiled.
query = build_custom_function_score_for_date(
query, child_order_by, default_score=0
query,
child_order_by,
default_score=0,
default_current_date=cd["request_date"],
)

case SEARCH_TYPES.RECAP_DOCUMENT if valid_child_order_by:
sort_field, order = child_order_by
if sort_field in ["dateFiled", "entry_date_filed"] and api_version:
if (
sort_field in ["dateFiled", "entry_date_filed"]
and api_version == "v4"
):
# Applies a custom function score to sort RECAPDocuments based
# on their docket dateFiled or entry_date_filed field. This
# serves as a workaround to enable the use of the search_after
# cursor for pagination on documents with a None dateFiled.
query = build_custom_function_score_for_date(
query, child_order_by, default_score=0
query,
child_order_by,
default_score=0,
default_current_date=cd["request_date"],
)
return query
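
Background on the workaround described in these comments: search_after needs every hit to expose a comparable sort value, and a docket with a null dateFiled has none, so the custom function score manufactures one through _score, with default_score=0 pushing undated dockets to the end. A hypothetical request body showing how that pairs with cursor pagination; the trimmed-down script, the docket_id tiebreaker, and the literal values below are assumptions for illustration, not taken from this commit:

from elasticsearch_dsl import Q

scored_query = Q(
    "function_score",
    query=Q("match", docket_child="docket"),
    script_score={
        "script": {
            # Undated dockets score 0L; dated ones score their epoch millis.
            "source": (
                "doc['dateFiled'].size() == 0 ? 0L : "
                "doc['dateFiled'].value.toInstant().toEpochMilli()"
            )
        }
    },
    boost_mode="replace",
)
search_body = {
    "query": scored_query.to_dict(),
    # Sorting on _score gives every docket a value, so dockets without a
    # dateFiled simply sink to the end instead of breaking the cursor.
    "sort": [{"_score": "desc"}, {"docket_id": "desc"}],
    # Sort values copied from the last hit of the previous page.
    "search_after": [1715644800000.0, 68421337],
    "size": 20,
}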

@@ -2279,7 +2316,7 @@ def build_full_join_es_queries(
hl_fields,
get_child_sorting_key(cd, api_version),
child_highlighting=child_highlighting,
api_version=api_version,
default_current_date=cd.get("request_date"),
)

if parties_filters and not has_child_query:
@@ -2293,7 +2330,7 @@ def build_full_join_es_queries(
query_hits_limit,
SEARCH_RECAP_CHILD_HL_FIELDS,
get_child_sorting_key(cd, api_version),
api_version=api_version,
default_current_date=cd.get("request_date"),
)

if has_child_query:
@@ -2346,19 +2383,21 @@ def build_full_join_es_queries(
minimum_should_match=1,
)
if parent_query:
parent_query = apply_custom_score_to_parent_query(
cd, parent_query, api_version
)
q_should.append(parent_query)

if not q_should:
return [], join_query

return (
final_query = apply_custom_score_to_parent_query(
cd,
Q(
"bool",
should=q_should,
),
api_version,
)
return (
final_query,
join_query,
)

@@ -2805,7 +2844,9 @@ def do_es_alert_estimation_query(
case _:
raise NotImplementedError

cd[after_field] = date.today() - timedelta(days=int(day_count))
cd[after_field] = datetime.date.today() - datetime.timedelta(
days=int(day_count)
)
cd[before_field] = None
estimation_query, _ = build_es_base_query(search_query, cd)

5 changes: 4 additions & 1 deletion cl/search/api_views.py
@@ -232,6 +232,10 @@ def list(self, request, *args, **kwargs):
if search_form.is_valid():
cd = search_form.cleaned_data
search_type = cd["type"]
paginator = ESCursorPagination()
cd["request_date"] = paginator.initialize_context_from_request(
request, search_type
)
search_query = DocketDocument.search()
highlighting_fields = {}
main_query, child_docs_query = do_es_api_query(
@@ -241,7 +245,6 @@
SEARCH_HL_TAG,
request.version,
)
paginator = ESCursorPagination()
es_list_instance = api_utils.CursorESList(
main_query,
child_docs_query,
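
Why the paginator construction moved above do_es_api_query: the cursor has to be decoded first so that cd["request_date"] is already populated when the date-based function scores are built. A toy, runnable sketch of that ordering; the class and function below are stand-ins, not the real ESCursorPagination or view:

import datetime

class ToyPaginator:
    """Stand-in for ESCursorPagination, only to illustrate the call order."""

    def initialize_context_from_request(self, cursor_date, search_type):
        # Reuse the date carried by an existing cursor, or stamp today's date
        # on a first-page request.
        return cursor_date or datetime.date.today()

def list_view(cd, cursor_date=None):
    paginator = ToyPaginator()
    # Runs before the query builders, so the stable date is in cleaned_data
    # by the time the function scores read cd["request_date"].
    cd["request_date"] = paginator.initialize_context_from_request(
        cursor_date, cd["type"]
    )
    return cd

# On page 2 and beyond, the cursor's date wins over "today".
cd = list_view({"type": "r"}, cursor_date=datetime.date(2024, 5, 14))
assert cd["request_date"] == datetime.date(2024, 5, 14)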