fix
All checks were successful
Deploy Prod / Build (pull_request) Successful in 9s
Deploy Prod / Push (pull_request) Successful in 12s
Deploy Prod / Deploy prod (pull_request) Successful in 10s

This commit is contained in:
Egor Matveev
2024-12-28 22:48:16 +03:00
parent c1249bfcd0
commit 6c6a549aff
2532 changed files with 562109 additions and 1 deletions

View File

@@ -0,0 +1,189 @@
import redis
from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
AGGREGATE_CMD,
CONFIG_CMD,
INFO_CMD,
PROFILE_CMD,
SEARCH_CMD,
SPELLCHECK_CMD,
SYNDUMP_CMD,
AsyncSearchCommands,
SearchCommands,
)
class Search(SearchCommands):
    """
    Create a client for talking to search.
    It abstracts the API of the module and lets you just use the engine.
    """

    class BatchIndexer:
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        def __init__(self, client, chunk_size=1000):
            # Documents are buffered on a non-transactional pipeline and
            # flushed automatically once `chunk_size` documents accumulate.
            self.client = client
            self.execute_command = client.execute_command
            self._pipeline = client.pipeline(transaction=False, shard_hint=None)
            self.total = 0          # documents added over this indexer's lifetime
            self.chunk_size = chunk_size
            self.current_chunk = 0  # documents buffered since the last flush

        def __del__(self):
            # Best-effort flush of pending documents on garbage collection.
            # NOTE(review): this may run at interpreter shutdown when the
            # underlying connection could already be unusable — confirm.
            if self.current_chunk:
                self.commit()

        def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            # Flush once the configured batch size is reached.
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def add_document_hash(self, doc_id, score=1.0, replace=False):
            """
            Add a hash to the batch query
            """
            self.client._add_document_hash(
                doc_id, conn=self._pipeline, score=score, replace=replace
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            self._pipeline.execute()
            self.current_chunk = 0

    def __init__(self, client, index_name="idx"):
        """
        Create a new Client for the given index_name.
        The default name is `idx`

        If conn is not None, we employ an already existing redis connection
        """
        self._MODULE_CALLBACKS = {}
        self.client = client
        self.index_name = index_name
        self.execute_command = client.execute_command
        self._pipeline = client.pipeline
        # Response parsers used when the server speaks RESP2.
        self._RESP2_MODULE_CALLBACKS = {
            INFO_CMD: self._parse_info,
            SEARCH_CMD: self._parse_search,
            AGGREGATE_CMD: self._parse_aggregate,
            PROFILE_CMD: self._parse_profile,
            SPELLCHECK_CMD: self._parse_spellcheck,
            CONFIG_CMD: self._parse_config_get,
            SYNDUMP_CMD: self._parse_syndump,
        }

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = Pipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        # Carry the index name onto the pipeline so SEARCH commands know
        # which index to target.
        p.index_name = self.index_name
        return p
class AsyncSearch(Search, AsyncSearchCommands):
    """Async variant of `Search`: batch flushing and pipelines are awaitable."""

    class BatchIndexer(Search.BatchIndexer):
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        async def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            # Flush asynchronously once the batch is full.
            if self.current_chunk >= self.chunk_size:
                await self.commit()

        async def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            await self._pipeline.execute()
            self.current_chunk = 0

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = AsyncPipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p
# Combines the SEARCH command builders with the core redis pipeline so
# module commands can be queued alongside classic commands.
class Pipeline(SearchCommands, redis.client.Pipeline):
    """Pipeline for the module."""
# Async counterpart of `Pipeline`, built on the asyncio pipeline class.
class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
    """AsyncPipeline for the module."""

View File

@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    """Coerce *s* to ``str``: decode bytes (dropping undecodable bytes),
    pass anything else through unchanged."""
    if isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    # Already a str, or not string-like at all — return as-is.
    return s

View File

@@ -0,0 +1,399 @@
from typing import List, Union
# Sentinel passed to `Reducer.alias()` meaning "alias the reducer after the
# field it operates on" (compared by identity, never by value).
FIELDNAME = object()
class Limit:
    """Offset/count pair rendered as a ``LIMIT`` clause for aggregations."""

    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        """Return ``["LIMIT", offset, count]``, or ``[]`` when count is 0."""
        if not self.count:
            return []
        return ["LIMIT", str(self.offset), str(self.count)]
class Reducer:
    """
    Base reducer object for all reducers.

    See the `redisearch.reducers` module for the actual reducers.
    """

    # Subclasses set this to the server-side reducer name (e.g. "COUNT").
    NAME = None

    def __init__(self, *args: List[str]) -> None:
        self._args = args
        self._field = None   # set by reducers that operate on a single field
        self._alias = None   # optional output alias, set via `alias()`

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
          special value `aggregation.FIELDNAME` then this reducer will be
          aliased using the same name as the field upon which it operates.
          Note that using `FIELDNAME` is only possible on reducers which
          operate on a single field value.

        This method returns the `Reducer` object making it suitable for
        chaining.
        """
        # FIELDNAME is compared by identity: it is an opaque sentinel object.
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            # Chop off initial '@'
            alias = self._field[1:]
        self._alias = alias
        return self

    @property
    def args(self) -> List[str]:
        # Raw positional arguments as passed to the constructor.
        return self._args
class SortDirection:
    """
    This special class is used to indicate sort direction.
    """

    # Subclasses set this to the textual direction ("ASC" / "DESC").
    DIRSTRING = None

    def __init__(self, field: str) -> None:
        self.field = field
class Asc(SortDirection):
    """
    Indicate that the given field should be sorted in ascending order
    """

    DIRSTRING = "ASC"
class Desc(SortDirection):
    """
    Indicate that the given field should be sorted in descending order
    """

    DIRSTRING = "DESC"
class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`)
        return the object itself, making them useful for chaining.
        """
        self._query = query
        self._aggregateplan = []   # accumulated GROUPBY/APPLY/SORTBY/... args
        self._loadfields = []      # explicit fields for the LOAD clause
        self._loadall = False      # when True, emit LOAD *
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []          # WITHCURSOR args, populated by cursor()
        self._dialect = None
        self._add_scores = False
        self._scorer = "TFIDF"     # default scoring function

    def load(self, *fields: List[str]) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields not specified, all the fields will be loaded.
          Otherwise, fields should be given in the format of `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
          or a list of strings. In both cases, the field should be specified
          as `@field`.
        - **reducers**: One or more reducers. Reducers may be found in the
          `aggregation` module.
        """
        # Normalize a bare string to a one-element list.
        fields = [fields] if isinstance(fields, str) else fields
        # NOTE(review): `reducers` arrives as a tuple from *reducers, so this
        # isinstance check appears vestigial — confirm against callers.
        reducers = [reducers] if isinstance(reducers, Reducer) else reducers

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
          the alias for the projection, and the value is the projection
          expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)
        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return

        Example of sorting the initial results:

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 10)\
            .group_by("@state", r.count())
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 1000)\
            .group_by("@state", r.count())\
            .limit(0, 10)
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries

        ### Parameters

        - **fields**: The fields by which to sort. This can be either a single
          field or a list of fields. If you wish to specify order, you can
          use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
          used instead of `LIMIT` and is also faster.

        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        AggregateRequest()\
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
            .sort_by(Desc("@paid"), max=10)
        ```
        """
        # NOTE(review): `fields` arrives as a tuple from *fields, so this
        # isinstance check appears vestigial — confirm against callers.
        if isinstance(fields, (str, SortDirection)):
            fields = [fields]

        fields_args = []
        for f in fields:
            if isinstance(f, SortDirection):
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args += [f]

        ret = ["SORTBY", str(len(fields_args))]
        ret.extend(fields_args)
        # `max` intentionally mirrors the server keyword; shadows the builtin.
        max = kwargs.get("max", 0)
        if max > 0:
            ret += ["MAX", str(max)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: Filter expression(s). This can either be a single
          string, or a list of strings.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])
        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def scorer(self, scorer: str) -> "AggregateRequest":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def verbatim(self) -> "AggregateRequest":
        # Disable query expansion/stemming: terms are matched verbatim.
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        # Request a server-side cursor; max_idle is given in seconds and
        # converted to milliseconds for the server.
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            args += ["MAXIDLE", str(max_idle * 1000)]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        """Serialize this request into the argument list for FT.AGGREGATE."""
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._scorer:
            ret.extend(["SCORER", self._scorer])

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret.append("LOAD")
            ret.append("*")
        elif self._loadfields:
            ret.append("LOAD")
            ret.append(str(len(self._loadfields)))
            ret.extend(self._loadfields)

        if self._dialect:
            ret.extend(["DIALECT", self._dialect])

        # Plan steps (GROUPBY/SORTBY/APPLY/FILTER/LIMIT) go last, in the
        # order they were chained.
        ret.extend(self._aggregateplan)

        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Cursor:
    """Handle for a server-side aggregation cursor (FT.CURSOR paging)."""

    def __init__(self, cid: int) -> None:
        self.cid = cid
        self.max_idle = 0  # optional MAXIDLE, 0 means "not set"
        self.count = 0     # optional COUNT, 0 means "not set"

    def build_args(self):
        """Render the cursor id followed by any MAXIDLE/COUNT options."""
        out = [str(self.cid)]
        if self.max_idle:
            out.extend(["MAXIDLE", str(self.max_idle)])
        if self.count:
            out.extend(["COUNT", str(self.count)])
        return out
class AggregateResult:
    """Result of an aggregation: the rows plus optional cursor and schema."""

    def __init__(self, rows, cursor: "Cursor", schema) -> None:
        # `cursor` is a Cursor when the request used WITHCURSOR, else None.
        # (Annotation is a forward-reference string so this class does not
        # depend on Cursor at definition time.)
        self.rows = rows
        self.cursor = cursor
        self.schema = schema

    def __repr__(self) -> str:
        # Fixed: the return annotation was `(str, str)`, which is not a valid
        # type for __repr__ — it must return a single str.
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
class Document:
    """
    Represents a single document in a result set
    """

    def __init__(self, id, payload=None, **fields):
        self.id = id
        self.payload = payload
        # Expose every remaining field as an attribute on the instance.
        for name, value in fields.items():
            setattr(self, name, value)

    def __repr__(self):
        return f"Document {self.__dict__}"

    def __getitem__(self, item):
        # Dict-style access simply delegates to attribute lookup.
        return getattr(self, item)

View File

@@ -0,0 +1,210 @@
from typing import List
from redis import DataError
class Field:
    """
    A class representing a field in a document.
    """

    NUMERIC = "NUMERIC"
    TEXT = "TEXT"
    WEIGHT = "WEIGHT"
    GEO = "GEO"
    TAG = "TAG"
    VECTOR = "VECTOR"
    SORTABLE = "SORTABLE"
    NOINDEX = "NOINDEX"
    AS = "AS"
    GEOSHAPE = "GEOSHAPE"
    INDEX_MISSING = "INDEXMISSING"
    INDEX_EMPTY = "INDEXEMPTY"

    def __init__(
        self,
        name: str,
        args: List[str] = None,
        sortable: bool = False,
        no_index: bool = False,
        index_missing: bool = False,
        index_empty: bool = False,
        as_name: str = None,
    ):
        """
        Create a new field object.

        Args:
            name: The name of the field.
            args: Type-specific arguments (e.g. ["TEXT", "WEIGHT", 1.0]).
            sortable: If `True`, the field will be sortable.
            no_index: If `True`, the field will not be indexed.
            index_missing: If `True`, it will be possible to search for documents
                that have this field missing.
            index_empty: If `True`, it will be possible to search for documents
                that have this field empty.
            as_name: If provided, this alias will be used for the field.

        Raises:
            ValueError: If `no_index` is set without `sortable`.
        """
        # Fail fast: validate the flag combination before building any state
        # (previously this check ran last, after args were already populated).
        if no_index and not sortable:
            raise ValueError("Non-Sortable non-Indexable fields are ignored")

        if args is None:
            args = []
        self.name = name
        self.args = args
        self.args_suffix = list()  # flags appended after the type args
        self.as_name = as_name

        if sortable:
            self.args_suffix.append(Field.SORTABLE)
        if no_index:
            self.args_suffix.append(Field.NOINDEX)
        if index_missing:
            self.args_suffix.append(Field.INDEX_MISSING)
        if index_empty:
            self.args_suffix.append(Field.INDEX_EMPTY)

    def append_arg(self, value):
        """Append a single raw argument to the field's type arguments."""
        self.args.append(value)

    def redis_args(self):
        """Serialize this field into the argument list used by FT.CREATE."""
        args = [self.name]
        if self.as_name:
            args += [self.AS, self.as_name]
        args += self.args
        args += self.args_suffix
        return args
class TextField(Field):
    """
    TextField is used to define a text field in a schema definition
    """

    NOSTEM = "NOSTEM"
    PHONETIC = "PHONETIC"

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        no_stem: bool = False,
        phonetic_matcher: str = None,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        super().__init__(name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)

        if no_stem:
            self.append_arg(self.NOSTEM)
        # Only the four supported double-metaphone matchers are honoured;
        # any other value is silently ignored.
        if phonetic_matcher and phonetic_matcher in (
            "dm:en",
            "dm:fr",
            "dm:pt",
            "dm:es",
        ):
            self.append_arg(self.PHONETIC)
            self.append_arg(phonetic_matcher)
        if withsuffixtrie:
            self.append_arg("WITHSUFFIXTRIE")
class NumericField(Field):
    """
    NumericField is used to define a numeric field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name, args=[Field.NUMERIC], **kwargs)
class GeoShapeField(Field):
    """
    GeoShapeField is used to enable within/contain indexing/searching
    """

    SPHERICAL = "SPHERICAL"
    FLAT = "FLAT"

    def __init__(self, name: str, coord_system=None, **kwargs):
        # When no coordinate system is given, the server default is used.
        field_args = [Field.GEOSHAPE]
        if coord_system:
            field_args.append(coord_system)
        super().__init__(name, args=field_args, **kwargs)
class GeoField(Field):
    """
    GeoField is used to define a geo-indexing field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name, args=[Field.GEO], **kwargs)
class TagField(Field):
    """
    TagField is a tag-indexing field with simpler compression and tokenization.
    See http://redisearch.io/Tags/
    """

    SEPARATOR = "SEPARATOR"
    CASESENSITIVE = "CASESENSITIVE"

    def __init__(
        self,
        name: str,
        separator: str = ",",
        case_sensitive: bool = False,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        tag_args = [Field.TAG, self.SEPARATOR, separator]
        if case_sensitive:
            tag_args.append(self.CASESENSITIVE)
        if withsuffixtrie:
            tag_args.append("WITHSUFFIXTRIE")
        super().__init__(name, args=tag_args, **kwargs)
class VectorField(Field):
    """
    Allows vector similarity queries against the value in this attribute.
    See https://oss.redis.com/redisearch/Vectors/#vector_fields.
    """

    def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
        """
        Create Vector Field. Notice that Vector cannot have sortable or no_index tag,
        although it's also a Field.

        ``name`` is the name of the field.

        ``algorithm`` can be "FLAT" or "HNSW".

        ``attributes`` each algorithm can have specific attributes. Some of them
        are mandatory and some of them are optional. See
        https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
        for more information.
        """
        # Vector fields may not carry SORTABLE/NOINDEX flags.
        if kwargs.get("sortable", False) or kwargs.get("no_index", False):
            raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")

        if algorithm.upper() not in ("FLAT", "HNSW"):
            raise DataError(
                "Realtime vector indexing supporting 2 Indexing Methods:"
                "'FLAT' and 'HNSW'."
            )

        # Flatten the attribute mapping into [key, value, key, value, ...].
        attr_li = []
        for key, value in attributes.items():
            attr_li.extend([key, value])

        super().__init__(
            name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs
        )

View File

@@ -0,0 +1,79 @@
from enum import Enum
class IndexType(Enum):
    """Enum of the currently supported index types."""

    # Index documents stored as Redis hashes.
    HASH = 1
    # Index documents stored as JSON values.
    JSON = 2
class IndexDefinition:
    """IndexDefinition is used to define a index definition for automatic
    indexing on Hash or Json update."""

    def __init__(
        self,
        prefix=None,
        filter=None,
        language_field=None,
        language=None,
        score_field=None,
        score=1.0,
        payload_field=None,
        index_type=None,
    ):
        """
        Build the FT.CREATE definition arguments into `self.args`.

        - **prefix**: key prefixes to index (previously defaulted to a shared
          mutable list `[]`; `None` is now the default to avoid that pitfall —
          passing a list still works exactly as before).
        - **filter**: server-side filter expression.
        - **language_field** / **language**: language configuration.
        - **score_field** / **score**: document score configuration.
        - **payload_field**: field holding the document payload.
        - **index_type**: an `IndexType` member, or None to omit `ON ...`.
        """
        if prefix is None:
            prefix = []
        self.args = []
        self._append_index_type(index_type)
        self._append_prefix(prefix)
        self._append_filter(filter)
        self._append_language(language_field, language)
        self._append_score(score_field, score)
        self._append_payload(payload_field)

    def _append_index_type(self, index_type):
        """Append `ON HASH` or `ON JSON` according to the enum."""
        # Early return for the default keeps the None case explicit.
        if index_type is None:
            return
        if index_type is IndexType.HASH:
            self.args.extend(["ON", "HASH"])
        elif index_type is IndexType.JSON:
            self.args.extend(["ON", "JSON"])
        else:
            raise RuntimeError(f"index_type must be one of {list(IndexType)}")

    def _append_prefix(self, prefix):
        """Append PREFIX."""
        if len(prefix) > 0:
            self.args.append("PREFIX")
            self.args.append(len(prefix))
            for p in prefix:
                self.args.append(p)

    def _append_filter(self, filter):
        """Append FILTER."""
        if filter is not None:
            self.args.append("FILTER")
            self.args.append(filter)

    def _append_language(self, language_field, language):
        """Append LANGUAGE_FIELD and LANGUAGE."""
        if language_field is not None:
            self.args.append("LANGUAGE_FIELD")
            self.args.append(language_field)
        if language is not None:
            self.args.append("LANGUAGE")
            self.args.append(language)

    def _append_score(self, score_field, score):
        """Append SCORE_FIELD and SCORE."""
        if score_field is not None:
            self.args.append("SCORE_FIELD")
            self.args.append(score_field)
        if score is not None:
            self.args.append("SCORE")
            self.args.append(score)

    def _append_payload(self, payload_field):
        """Append PAYLOAD_FIELD."""
        if payload_field is not None:
            self.args.append("PAYLOAD_FIELD")
            self.args.append(payload_field)

View File

@@ -0,0 +1,377 @@
from typing import List, Optional, Union
class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object, so they can be chained,
    i.e. `Query("foo").verbatim().filter(...)` etc.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.
        The query string is set in the constructor, and other options have
        setter functions.
        """
        self._query_string: str = query_string
        self._offset: int = 0
        self._num: int = 10           # default page size
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = list()
        self._ids: Optional[List[str]] = None
        self._slop: int = -1          # -1 means "SLOP not set"
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional[SortbyField] = None
        self._return_fields: List = []
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: Optional[int] = None

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        self._ids = ids
        return self

    def return_fields(self, *fields) -> "Query":
        """Add fields to return fields."""
        for field in fields:
            self.return_field(field)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        # None means "leave this field as raw bytes" when decoding results.
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            # += with a tuple extends the list with "AS", <alias>.
            self._return_fields += ("AS", as_field)
        return self

    def _mk_field_list(self, fields: List[str]) -> List:
        # Accept None, a single string, or an iterable of strings.
        if not fields:
            return []
        return [fields] if isinstance(fields, str) else list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field which contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise all results are summarized.

        Server side defaults are used for each option (except `fields`)
        if not specified

        - **fields** List of fields to summarize. All fields are summarized
          if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields

        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]

        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified then only those mentioned fields are
          highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        # (Annotation corrected: this method returns self for chaining, not
        # None as previously declared.)
        args = ["HIGHLIGHT"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if tags:
            args += ["TAGS"] + list(tags)

        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        :param language: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """Overrides the timeout parameter of the module"""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document.
        i.e. for the query "hello world", we do not match "world hello"
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[str]:
        """Format the redis arguments for this query and return them."""
        args = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[str]:
        # Serialize every configured option into its command-argument form.
        args = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args.append("INFIELDS")
            args.append(len(self._fields))
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        if self._filters:
            for flt in self._filters:
                if not isinstance(flt, Filter):
                    raise AttributeError("Did not receive a Filter object.")
                args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args.append("INKEYS")
            args.append(len(self._ids))
            args += self._ids
        # SLOP of 0 is meaningful (exact phrase), hence the >= 0 check.
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            args.append("RETURN")
            args.append(len(self._return_fields))
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]

        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e. use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: List[str]) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: A list of strings, case sensitive field names
          from the defined schema.
        """
        self._fields = fields
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
          corresponding field
        """
        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add a expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Filter:
    """Base query filter: a keyword, the field it applies to, and its args."""

    def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
        self.args = [keyword, field, *args]
class NumericFilter(Filter):
    """Numeric range filter; use INF / NEG_INF for open-ended bounds."""

    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        # An exclusive bound is rendered with a leading '(' per server syntax.
        lower = f"({minval}" if minExclusive else minval
        upper = f"({maxval}" if maxExclusive else maxval
        super().__init__("FILTER", field, lower, upper)
class GeoFilter(Filter):
    """Geo radius filter around a lon/lat point, in the given unit."""

    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        super().__init__("GEOFILTER", field, lon, lat, radius, unit)
class SortbyField:
    """A field plus direction, serialized for a SORTBY clause."""

    def __init__(self, field: str, asc=True) -> None:
        direction = "ASC" if asc else "DESC"
        self.args = [field, direction]

View File

@@ -0,0 +1,317 @@
def tags(*t):
    """
    Indicate that the values should be matched to a tag field

    ### Parameters

    - **t**: Tags to search for

    Raises ValueError when called with no tags.
    """
    if not t:
        raise ValueError("At least one tag must be specified")

    return TagValue(*t)
def between(a, b, inclusive_min=True, inclusive_max=True):
    """
    Indicate that value is a numeric range

    None bounds are converted to -inf / inf by RangeValue.
    """
    return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)
def equal(n):
    """
    Match a numeric value

    Implemented as the degenerate inclusive range [n, n].
    """
    return between(n, n)
def lt(n):
    """
    Match any value less than n
    """
    return between(None, n, inclusive_max=False)
def le(n):
    """
    Match any value less or equal to n
    """
    return between(None, n, inclusive_max=True)
def gt(n):
    """
    Match any value greater than n
    """
    return between(n, None, inclusive_min=False)
def ge(n):
    """
    Match any value greater or equal to n
    """
    return between(n, None, inclusive_min=True)
def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region
    """
    # NOTE(review): GeoValue's constructor signature is (lon, lat, ...), but
    # this helper passes (lat, lon) positionally — the two appear swapped.
    # Confirm the intended argument order against the server's geo-filter
    # syntax before relying on this helper.
    return GeoValue(lat, lon, radius, unit)
class Value:
    """Abstract base for field values used when building query strings."""

    @property
    def combinable(self):
        """
        Whether this type of value may be combined with other values
        for the same field. This makes the filter potentially more efficient
        """
        return False

    @staticmethod
    def make_value(v):
        """
        Convert an object to a value, if it is not a value already
        """
        # Values pass through untouched; everything else is wrapped.
        if isinstance(v, Value):
            return v
        return ScalarValue(v)

    def __str__(self):
        return self.to_string()

    def to_string(self):
        # Subclasses must provide their own serialization.
        raise NotImplementedError()
class RangeValue(Value):
    """Numeric range value rendered as ``[min max]`` with optional '(' bounds."""

    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        # None bounds become the open-ended -inf / inf markers.
        self.range = [
            "-inf" if a is None else str(a),
            "inf" if b is None else str(b),
        ]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        lo_mark = "" if self.inclusive_min else "("
        hi_mark = "" if self.inclusive_max else "("
        return f"[{lo_mark}{self.range[0]} {hi_mark}{self.range[1]}]"
class ScalarValue(Value):
    """A plain scalar value; combinable with other scalars on one field."""

    combinable = True

    def __init__(self, v):
        # Store the stringified form, which is what the query needs.
        self.v = str(v)

    def to_string(self):
        return self.v
class TagValue(Value):
    """A set of tags rendered as ``{tag1 | tag2 | ...}``."""

    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        joined = " | ".join(str(t) for t in self.tags)
        return "{" + joined + "}"
class GeoValue(Value):
    """Geographic radius value rendered as ``[lon lat radius unit]``."""

    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        # Longitude comes first in the RediSearch geo filter syntax.
        return "[{} {} {} {}]".format(self.lon, self.lat, self.radius, self.unit)
class Node:
    """A node in the query tree, combining child nodes and field filters."""

    def __init__(self, *children, **kwparams):
        """
        Create a node

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
            `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
            type.

            The semantics of multiple conditions are dependent on the type of
            query. For an `intersection` node, this amounts to a logical AND,
            for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
            and the value should be a field value. This can be one of the
            following:

            - Simple string (for text field matches)
            - value returned by one of the helper functions
            - list of either a string or a value

        ### Examples

        Field `num` should be between 1 and 10
        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`
        ```
        union(name=("bob", "john"))
        ```

        Don't select countries in Israel, Japan, or US
        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """
        self.params = []

        # Normalize every keyword argument into a list of Value objects.
        kvparams = {}
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (str, int, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                # Assume an iterable of strings and/or Value objects.
                curvals.extend(Value.make_value(subv) for subv in v)

        self.params += [Node.to_node(p) for p in children]

        # Fix: iterate the normalized `kvparams` (lists of Value objects)
        # rather than the raw `kwparams` -- join_fields relies on the
        # Value interface (.combinable / .to_string()), so raw strings or
        # tuples passed by the caller would crash or misrender here.
        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        # A single value is emitted directly. Non-combinable values each get
        # their own @field clause; combinable ones are grouped in parentheses.
        if len(vals) == 1:
            return [BaseNode(f"@{key}:{vals[0].to_string()}")]
        if not vals[0].combinable:
            return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
        s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
        return [s]

    @classmethod
    def to_node(cls, obj):  # noqa
        # Wrap raw strings (or anything non-Node) as a literal BaseNode.
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        # Subclasses define the string used to join child expressions.
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"

    def _should_use_paren(self, optval):
        # An explicit request wins; otherwise parenthesize multi-child nodes.
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()
class BaseNode(Node):
    """A leaf node wrapping a raw, pre-rendered query string."""

    def __init__(self, s):
        super().__init__()
        self.s = str(s)

    def to_string(self, with_parens=None):
        # A raw string is emitted verbatim; parenthesization does not apply.
        return self.s
class IntersectNode(Node):
    """
    Create an intersection node. All children need to be satisfied in order for
    this node to evaluate as true
    """

    # Children joined by whitespace, which RediSearch treats as logical AND.
    JOINSTR = " "
class UnionNode(Node):
    """
    Create a union node. Any of the children need to be satisfied in order for
    this node to evaluate as true
    """

    # Children joined by "|", which RediSearch treats as logical OR.
    JOINSTR = "|"
class DisjunctNode(IntersectNode):
    """
    Create a disjunct node. In order for this node to be true, all of its
    children must evaluate to false
    """

    def to_string(self, with_parens=None):
        # Negate the joined children with "-", wrapping in parentheses when
        # the node would normally be parenthesized.
        use_parens = self._should_use_paren(with_parens)
        inner = "-" + super().to_string(with_parens=False)
        return f"({inner})" if use_parens else inner
class DistjunctUnion(DisjunctNode):
    """
    This node is true if *all* of its children are false. This is equivalent to
    ```
    disjunct(union(...))
    ```
    """

    # Children are OR-ed together before the whole group is negated.
    JOINSTR = "|"
class OptionalNode(IntersectNode):
    """
    Create an optional node. If this nodes evaluates to true, then the document
    will be rated higher in score/rank.
    """

    def to_string(self, with_parens=None):
        # Prefix the joined children with "~" (optional match), wrapping in
        # parentheses when the node would normally be parenthesized.
        use_parens = self._should_use_paren(with_parens)
        inner = "~" + super().to_string(with_parens=False)
        return f"({inner})" if use_parens else inner
def intersect(*args, **kwargs):
    """Return an `IntersectNode` (logical AND) over the given conditions."""
    return IntersectNode(*args, **kwargs)
def union(*args, **kwargs):
    """Return a `UnionNode` (logical OR) over the given conditions."""
    return UnionNode(*args, **kwargs)
def disjunct(*args, **kwargs):
    """Return a `DisjunctNode` (negated AND) over the given conditions."""
    return DisjunctNode(*args, **kwargs)
def disjunct_union(*args, **kwargs):
    """Return a `DistjunctUnion` (negated OR) over the given conditions."""
    return DistjunctUnion(*args, **kwargs)
def querystring(*args, **kwargs):
    """Render the given conditions as an intersection query string."""
    return intersect(*args, **kwargs).to_string()

View File

@@ -0,0 +1,182 @@
from typing import Union
from .aggregation import Asc, Desc, Reducer, SortDirection
class FieldOnlyReducer(Reducer):
    """See https://redis.io/docs/interact/search-and-query/search/aggregations/"""

    def __init__(self, field: str) -> None:
        super().__init__(field)
        # Keep the source field for introspection by subclasses and callers.
        self._field = field
class count(Reducer):
    """
    Counts the number of results in the group
    """

    NAME = "COUNT"  # FT.AGGREGATE REDUCE function name

    def __init__(self) -> None:
        # COUNT takes no arguments.
        super().__init__()
class sum(FieldOnlyReducer):
    """
    Calculates the sum of all the values in the given fields within the group
    """

    NAME = "SUM"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class min(FieldOnlyReducer):
    """
    Calculates the smallest value in the given field within the group
    """

    NAME = "MIN"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class max(FieldOnlyReducer):
    """
    Calculates the largest value in the given field within the group
    """

    NAME = "MAX"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class avg(FieldOnlyReducer):
    """
    Calculates the mean value in the given field within the group
    """

    NAME = "AVG"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class tolist(FieldOnlyReducer):
    """
    Returns all the matched properties in a list
    """

    NAME = "TOLIST"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class count_distinct(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field
    """

    NAME = "COUNT_DISTINCT"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class count_distinctish(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in the
    group for the given field. This uses a faster algorithm than
    `count_distinct` but is less accurate
    """

    NAME = "COUNT_DISTINCTISH"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        # Explicit for consistency with the sibling field reducers; behavior
        # is identical to the inherited FieldOnlyReducer.__init__.
        super().__init__(field)
class quantile(Reducer):
    """
    Return the value for the nth percentile within the range of values for the
    field within the group.
    """

    NAME = "QUANTILE"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str, pct: float) -> None:
        # The percentile is sent to Redis as a string argument.
        super().__init__(field, str(pct))
        self._field = field
class stddev(FieldOnlyReducer):
    """
    Return the standard deviation for the values within the group
    """

    NAME = "STDDEV"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str) -> None:
        super().__init__(field)
class first_value(Reducer):
    """
    Selects the first value within the group according to sorting parameters
    """

    NAME = "FIRST_VALUE"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
        """
        Selects the first value of the given field within the group.

        ### Parameter

        - **field**: Source field used for the value
        - **byfields**: How to sort the results. This can be either the
            *class* of `aggregation.Asc` or `aggregation.Desc` in which
            case the field `field` is also used as the sort input.

            `byfields` can also be one or more *instances* of `Asc` or `Desc`
            indicating the sort order for these fields
        """
        fieldstrs = []
        # A bare Asc/Desc *class* means "sort by the source field itself";
        # instantiate it with `field` so the loop below can treat both
        # calling conventions uniformly.
        if (
            len(byfields) == 1
            and isinstance(byfields[0], type)
            and issubclass(byfields[0], SortDirection)
        ):
            byfields = [byfields[0](field)]
        for f in byfields:
            fieldstrs += [f.field, f.DIRSTRING]
        # Final argument form: <field> [BY <field> <direction> ...]
        args = [field]
        if fieldstrs:
            args += ["BY"] + fieldstrs
        super().__init__(*args)
        self._field = field
class random_sample(Reducer):
    """
    Returns a random sample of items from the dataset, from the given property
    """

    NAME = "RANDOM_SAMPLE"  # FT.AGGREGATE REDUCE function name

    def __init__(self, field: str, size: int) -> None:
        """
        ### Parameter

        **field**: Field to sample from
        **size**: Return this many items (can be less)
        """
        # The sample size is sent to Redis as a string argument.
        args = [field, str(size)]
        super().__init__(*args)
        self._field = field

View File

@@ -0,0 +1,87 @@
from typing import Optional
from ._util import to_string
from .document import Document
class Result:
    """
    Represents the result of a search query, and has an array of Document
    objects
    """

    def __init__(
        self,
        res,
        hascontent,
        duration=0,
        has_payload=False,
        with_scores=False,
        field_encodings: Optional[dict] = None,
    ):
        """
        - res: raw FT.SEARCH reply -- a flat list whose first element is the
          total hit count, followed by per-document entries
        - hascontent: whether each hit includes a field/value content array
        - duration: the execution time of the query
        - has_payload: whether the query has payloads
        - with_scores: whether the query has scores
        - field_encodings: a dictionary of field encodings if any is provided
        """
        self.total = res[0]
        self.duration = duration
        self.docs = []
        # Each hit occupies `step` consecutive reply entries: the id, plus one
        # entry each for content, payload and score when requested.
        step = 1
        if hascontent:
            step = step + 1
        if has_payload:
            step = step + 1
        if with_scores:
            step = step + 1
        # Relative position (after the id) of the payload entry; the score,
        # when present, always directly follows the id.
        offset = 2 if with_scores else 1
        for i in range(1, len(res), step):
            id = to_string(res[i])
            payload = to_string(res[i + offset]) if has_payload else None
            # Content follows the payload when both are present.
            fields_offset = offset + 1 if has_payload else offset
            score = float(res[i + 1]) if with_scores else None
            fields = {}
            if hascontent and res[i + fields_offset] is not None:
                # The content entry is a flat [key, value, key, value, ...] list.
                keys = map(to_string, res[i + fields_offset][::2])
                values = res[i + fields_offset][1::2]
                for key, value in zip(keys, values):
                    if field_encodings is None or key not in field_encodings:
                        fields[key] = to_string(value)
                        continue
                    encoding = field_encodings[key]
                    # If the encoding is None, we don't need to decode the value
                    if encoding is None:
                        fields[key] = value
                    else:
                        fields[key] = to_string(value, encoding=encoding)
            # Drop a content field named "id" so it cannot clash with the
            # Document's own id attribute.
            try:
                del fields["id"]
            except KeyError:
                pass
            # JSON responses store the whole document under "$"; expose it
            # under the friendlier "json" key instead.
            try:
                fields["json"] = fields["$"]
                del fields["$"]
            except KeyError:
                pass
            doc = (
                Document(id, score=score, payload=payload, **fields)
                if with_scores
                else Document(id, payload=payload, **fields)
            )
            self.docs.append(doc)

    def __repr__(self) -> str:
        return f"Result{{{self.total} total, docs: {self.docs}}}"

View File

@@ -0,0 +1,55 @@
from typing import Optional
from ._util import to_string
class Suggestion:
    """
    Represents a single suggestion being sent or returned from the
    autocomplete server
    """

    def __init__(
        self, string: str, score: float = 1.0, payload: Optional[str] = None
    ) -> None:
        # Normalize bytes replies to str; payload may remain None.
        self.string = to_string(string)
        self.payload = to_string(payload)
        self.score = score

    def __repr__(self) -> str:
        return self.string
class SuggestionParser:
    """
    Internal class used to parse results from the `SUGGET` command.
    This needs to consume either 1, 2, or 3 values at a time from
    the return value depending on what objects were requested
    """

    def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
        self.with_scores = with_scores
        self.with_payloads = with_payloads
        # sugsize is the number of reply entries per suggestion; the *idx
        # attributes are offsets of the score/payload relative to the string.
        if with_scores and with_payloads:
            self.sugsize = 3
            self._scoreidx = 1
            self._payloadidx = 2
        elif with_scores:
            self.sugsize = 2
            self._scoreidx = 1
        elif with_payloads:
            self.sugsize = 2
            self._payloadidx = 1
        else:
            self.sugsize = 1
            self._scoreidx = -1  # unused: with_scores is False on this branch
        self._sugs = ret

    def __iter__(self):
        # Walk the flat reply in strides of `sugsize`, yielding one Suggestion
        # per stride with defaults filled in for missing score/payload.
        for i in range(0, len(self._sugs), self.sugsize):
            ss = self._sugs[i]
            score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
            payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
            yield Suggestion(ss, score, payload)