Skip to content

Commit

Permalink
Add async supoort for SEARCH commands (#2096)
Browse files Browse the repository at this point in the history
* Add async supoort for SEARCH commands

* linters

* linters

* linters

* linters

* linters
  • Loading branch information
dvora-h committed Apr 28, 2022
1 parent c29d158 commit 1475e5c
Show file tree
Hide file tree
Showing 10 changed files with 6,865 additions and 5 deletions.
4 changes: 2 additions & 2 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
)
from redis.commands import (
AsyncCoreCommands,
AsyncRedisModuleCommands,
AsyncSentinelCommands,
RedisModuleCommands,
list_or_args,
)
from redis.compat import Protocol, TypedDict
Expand Down Expand Up @@ -81,7 +81,7 @@ async def __call__(self, response: Any, **kwargs):


class Redis(
AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
"""
Implementation of the Redis protocol.
Expand Down
3 changes: 2 additions & 1 deletion redis/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .core import AsyncCoreCommands, CoreCommands
from .helpers import list_or_args
from .parser import CommandsParser
from .redismodules import RedisModuleCommands
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .sentinel import AsyncSentinelCommands, SentinelCommands

__all__ = [
Expand All @@ -11,6 +11,7 @@
"AsyncCoreCommands",
"CoreCommands",
"list_or_args",
"AsyncRedisModuleCommands",
"RedisModuleCommands",
"AsyncSentinelCommands",
"SentinelCommands",
Expand Down
10 changes: 10 additions & 0 deletions redis/commands/redismodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ def graph(self, index_name="idx"):

g = Graph(client=self, name=index_name)
return g


class AsyncRedisModuleCommands(RedisModuleCommands):
def ft(self, index_name="idx"):
"""Access the search namespace, providing support for redis search."""

from .search import AsyncSearch

s = AsyncSearch(client=self, index_name=index_name)
return s
64 changes: 63 additions & 1 deletion redis/commands/search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import redis

from .commands import SearchCommands
from .commands import AsyncSearchCommands, SearchCommands


class Search(SearchCommands):
Expand Down Expand Up @@ -112,5 +112,67 @@ def pipeline(self, transaction=True, shard_hint=None):
return p


class AsyncSearch(Search, AsyncSearchCommands):
class BatchIndexer(Search.BatchIndexer):
"""
A batch indexer allows you to automatically batch
document indexing in pipelines, flushing it every N documents.
"""

async def add_document(
self,
doc_id,
nosave=False,
score=1.0,
payload=None,
replace=False,
partial=False,
no_create=False,
**fields,
):
"""
Add a document to the batch query
"""
self.client._add_document(
doc_id,
conn=self._pipeline,
nosave=nosave,
score=score,
payload=payload,
replace=replace,
partial=partial,
no_create=no_create,
**fields,
)
self.current_chunk += 1
self.total += 1
if self.current_chunk >= self.chunk_size:
await self.commit()

async def commit(self):
"""
Manually commit and flush the batch indexing query
"""
await self._pipeline.execute()
self.current_chunk = 0

def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the SEARCH module, that can be used for executing
SEARCH commands, as well as classic core commands.
"""
p = AsyncPipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self.MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
p.index_name = self.index_name
return p


class Pipeline(SearchCommands, redis.client.Pipeline):
"""Pipeline for the module."""


class AsyncPipeline(AsyncSearchCommands, redis.asyncio.client.Pipeline):
"""AsyncPipeline for the module."""
242 changes: 242 additions & 0 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,3 +857,245 @@ def syndump(self):
""" # noqa
raw = self.execute_command(SYNDUMP_CMD, self.index_name)
return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)}


class AsyncSearchCommands(SearchCommands):
async def info(self):
"""
Get info an stats about the the current index, including the number of
documents, memory consumption, etc
For more information https://oss.redis.com/redisearch/Commands/#ftinfo
"""

res = await self.execute_command(INFO_CMD, self.index_name)
it = map(to_string, res)
return dict(zip(it, it))

async def search(
self,
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
Search the index for a given query, and return a result of documents
### Parameters
- **query**: the search query. Either a text for simple queries with
default parameters, or a Query object for complex queries.
See RediSearch's documentation on query format
For more information: https://oss.redis.com/redisearch/Commands/#ftsearch
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = await self.execute_command(SEARCH_CMD, *args)

if isinstance(res, Pipeline):
return res

return Result(
res,
not query._no_content,
duration=(time.time() - st) * 1000.0,
has_payload=query._with_payloads,
with_scores=query._with_scores,
)

async def aggregate(
self,
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
Issue an aggregation query.
### Parameters
**query**: This can be either an `AggregateRequest`, or a `Cursor`
An `AggregateResult` object is returned. You can access the rows from
its `rows` property, which will always yield the rows of the result.
For more information: https://oss.redis.com/redisearch/Commands/#ftaggregate
""" # noqa
if isinstance(query, AggregateRequest):
has_cursor = bool(query._cursor)
cmd = [AGGREGATE_CMD, self.index_name] + query.build_args()
elif isinstance(query, Cursor):
has_cursor = True
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
else:
raise ValueError("Bad query", query)
if query_params is not None:
cmd += self.get_params_args(query_params)

raw = await self.execute_command(*cmd)
return self._get_aggregate_result(raw, query, has_cursor)

async def spellcheck(self, query, distance=None, include=None, exclude=None):
"""
Issue a spellcheck query
### Parameters
**query**: search query.
**distance***: the maximal Levenshtein distance for spelling
suggestions (default: 1, max: 4).
**include**: specifies an inclusion custom dictionary.
**exclude**: specifies an exclusion custom dictionary.
For more information: https://oss.redis.com/redisearch/Commands/#ftspellcheck
""" # noqa
cmd = [SPELLCHECK_CMD, self.index_name, query]
if distance:
cmd.extend(["DISTANCE", distance])

if include:
cmd.extend(["TERMS", "INCLUDE", include])

if exclude:
cmd.extend(["TERMS", "EXCLUDE", exclude])

raw = await self.execute_command(*cmd)

corrections = {}
if raw == 0:
return corrections

for _correction in raw:
if isinstance(_correction, int) and _correction == 0:
continue

if len(_correction) != 3:
continue
if not _correction[2]:
continue
if not _correction[2][0]:
continue

corrections[_correction[1]] = [
{"score": _item[0], "suggestion": _item[1]} for _item in _correction[2]
]

return corrections

async def config_set(self, option, value):
"""Set runtime configuration option.
### Parameters
- **option**: the name of the configuration option.
- **value**: a value for the configuration option.
For more information: https://oss.redis.com/redisearch/Commands/#ftconfig
""" # noqa
cmd = [CONFIG_CMD, "SET", option, value]
raw = await self.execute_command(*cmd)
return raw == "OK"

async def config_get(self, option):
"""Get runtime configuration option value.
### Parameters
- **option**: the name of the configuration option.
For more information: https://oss.redis.com/redisearch/Commands/#ftconfig
""" # noqa
cmd = [CONFIG_CMD, "GET", option]
res = {}
raw = await self.execute_command(*cmd)
if raw:
for kvs in raw:
res[kvs[0]] = kvs[1]
return res

async def load_document(self, id):
"""
Load a single document by id
"""
fields = await self.client.hgetall(id)
f2 = {to_string(k): to_string(v) for k, v in fields.items()}
fields = f2

try:
del fields["id"]
except KeyError:
pass

return Document(id=id, **fields)

async def sugadd(self, key, *suggestions, **kwargs):
"""
Add suggestion terms to the AutoCompleter engine. Each suggestion has
a score and string.
If kwargs["increment"] is true and the terms are already in the
server's dictionary, we increment their scores.
For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugadd
""" # noqa
# If Transaction is not False it will MULTI/EXEC which will error
pipe = self.pipeline(transaction=False)
for sug in suggestions:
args = [SUGADD_COMMAND, key, sug.string, sug.score]
if kwargs.get("increment"):
args.append("INCR")
if sug.payload:
args.append("PAYLOAD")
args.append(sug.payload)

pipe.execute_command(*args)

return (await pipe.execute())[-1]

async def sugget(
self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False
):
"""
Get a list of suggestions from the AutoCompleter, for a given prefix.
Parameters:
prefix : str
The prefix we are searching. **Must be valid ascii or utf-8**
fuzzy : bool
If set to true, the prefix search is done in fuzzy mode.
**NOTE**: Running fuzzy searches on short (<3 letters) prefixes
can be very
slow, and even scan the entire index.
with_scores : bool
If set to true, we also return the (refactored) score of
each suggestion.
This is normally not needed, and is NOT the original score
inserted into the index.
with_payloads : bool
Return suggestion payloads
num : int
The maximum number of results we return. Note that we might
return less. The algorithm trims irrelevant suggestions.
Returns:
list:
A list of Suggestion objects. If with_scores was False, the
score of all suggestions is 1.
For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugget
""" # noqa
args = [SUGGET_COMMAND, key, prefix, "MAX", num]
if fuzzy:
args.append(FUZZY)
if with_scores:
args.append(WITHSCORES)
if with_payloads:
args.append(WITHPAYLOADS)

ret = await self.execute_command(*args)
results = []
if not ret:
return results

parser = SuggestionParser(with_scores, with_payloads, ret)
return [s for s in parser]
2 changes: 1 addition & 1 deletion tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def wait_for_command(
# generate key
redis_version = REDIS_INFO["version"]
if Version(redis_version) >= Version("5.0.0"):
id_str = str(client.client_id())
id_str = str(await client.client_id())
else:
id_str = f"{random.randrange(2 ** 32):08x}"
key = f"__REDIS-PY-{id_str}__"
Expand Down

0 comments on commit 1475e5c

Please sign in to comment.