fix
@@ -0,0 +1,189 @@
import redis

from ...asyncio.client import Pipeline as AsyncioPipeline
from .commands import (
    AGGREGATE_CMD,
    CONFIG_CMD,
    INFO_CMD,
    PROFILE_CMD,
    SEARCH_CMD,
    SPELLCHECK_CMD,
    SYNDUMP_CMD,
    AsyncSearchCommands,
    SearchCommands,
)


class Search(SearchCommands):
    """
    Create a client for talking to search.

    It abstracts the API of the module and lets you just use the engine.
    """

    class BatchIndexer:
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        def __init__(self, client, chunk_size=1000):
            self.client = client
            self.execute_command = client.execute_command
            self._pipeline = client.pipeline(transaction=False, shard_hint=None)
            self.total = 0
            self.chunk_size = chunk_size
            self.current_chunk = 0

        def __del__(self):
            if self.current_chunk:
                self.commit()

        def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def add_document_hash(self, doc_id, score=1.0, replace=False):
            """
            Add a hash to the batch query
            """
            self.client._add_document_hash(
                doc_id, conn=self._pipeline, score=score, replace=replace
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                self.commit()

        def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            self._pipeline.execute()
            self.current_chunk = 0

    def __init__(self, client, index_name="idx"):
        """
        Create a new Client for the given index_name.
        The default name is `idx`

        If conn is not None, we employ an already existing redis connection
        """
        self._MODULE_CALLBACKS = {}
        self.client = client
        self.index_name = index_name
        self.execute_command = client.execute_command
        self._pipeline = client.pipeline
        self._RESP2_MODULE_CALLBACKS = {
            INFO_CMD: self._parse_info,
            SEARCH_CMD: self._parse_search,
            AGGREGATE_CMD: self._parse_aggregate,
            PROFILE_CMD: self._parse_profile,
            SPELLCHECK_CMD: self._parse_spellcheck,
            CONFIG_CMD: self._parse_config_get,
            SYNDUMP_CMD: self._parse_syndump,
        }

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = Pipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class AsyncSearch(Search, AsyncSearchCommands):
    class BatchIndexer(Search.BatchIndexer):
        """
        A batch indexer allows you to automatically batch
        document indexing in pipelines, flushing it every N documents.
        """

        async def add_document(
            self,
            doc_id,
            nosave=False,
            score=1.0,
            payload=None,
            replace=False,
            partial=False,
            no_create=False,
            **fields,
        ):
            """
            Add a document to the batch query
            """
            self.client._add_document(
                doc_id,
                conn=self._pipeline,
                nosave=nosave,
                score=score,
                payload=payload,
                replace=replace,
                partial=partial,
                no_create=no_create,
                **fields,
            )
            self.current_chunk += 1
            self.total += 1
            if self.current_chunk >= self.chunk_size:
                await self.commit()

        async def commit(self):
            """
            Manually commit and flush the batch indexing query
            """
            await self._pipeline.execute()
            self.current_chunk = 0

    def pipeline(self, transaction=True, shard_hint=None):
        """Creates a pipeline for the SEARCH module, that can be used for executing
        SEARCH commands, as well as classic core commands.
        """
        p = AsyncPipeline(
            connection_pool=self.client.connection_pool,
            response_callbacks=self._MODULE_CALLBACKS,
            transaction=transaction,
            shard_hint=shard_hint,
        )
        p.index_name = self.index_name
        return p


class Pipeline(SearchCommands, redis.client.Pipeline):
    """Pipeline for the module."""


class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
    """AsyncPipeline for the module."""
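# Usage sketch, assuming a local Redis server with the RediSearch module
# loaded (the index and field names here are illustrative):
#
#     import redis
#     from redis.commands.search.field import TextField
#     from redis.commands.search.query import Query
#
#     r = redis.Redis()
#     search = r.ft("myIdx")  # Search instance bound to the "myIdx" index
#     search.create_index([TextField("title")])
#     r.hset("doc:1", mapping={"title": "hello world"})
#     res = search.search(Query("hello"))  # Result with matching Documents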
@@ -0,0 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    if isinstance(s, str):
        return s
    elif isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    else:
        return s  # Not a string we care about
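# Usage sketch:
#
#     to_string(b"hello")  # -> 'hello'
#     to_string("hello")   # -> 'hello'
#     to_string(42)        # -> 42 (non-string values pass through unchanged)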
@@ -0,0 +1,399 @@
from typing import List, Union

FIELDNAME = object()


class Limit:
    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset = offset
        self.count = count

    def build_args(self):
        if self.count:
            return ["LIMIT", str(self.offset), str(self.count)]
        else:
            return []


class Reducer:
    """
    Base reducer object for all reducers.

    See the `reducers` module for the actual reducers.
    """

    NAME = None

    def __init__(self, *args: List[str]) -> None:
        self._args = args
        self._field = None
        self._alias = None

    def alias(self, alias: str) -> "Reducer":
        """
        Set the alias for this reducer.

        ### Parameters

        - **alias**: The value of the alias for this reducer. If this is the
            special value `aggregation.FIELDNAME` then this reducer will be
            aliased using the same name as the field upon which it operates.
            Note that using `FIELDNAME` is only possible on reducers which
            operate on a single field value.

        This method returns the `Reducer` object making it suitable for
        chaining.
        """
        if alias is FIELDNAME:
            if not self._field:
                raise ValueError("Cannot use FIELDNAME alias with no field")
            # Chop off initial '@'
            alias = self._field[1:]
        self._alias = alias
        return self

    @property
    def args(self) -> List[str]:
        return self._args


class SortDirection:
    """
    This special class is used to indicate sort direction.
    """

    DIRSTRING = None

    def __init__(self, field: str) -> None:
        self.field = field


class Asc(SortDirection):
    """
    Indicate that the given field should be sorted in ascending order
    """

    DIRSTRING = "ASC"


class Desc(SortDirection):
    """
    Indicate that the given field should be sorted in descending order
    """

    DIRSTRING = "DESC"


class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`)
        return the object itself, making them useful for chaining.
        """
        self._query = query
        self._aggregateplan = []
        self._loadfields = []
        self._loadall = False
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []
        self._dialect = None
        self._add_scores = False
        self._scorer = "TFIDF"

    def load(self, *fields: List[str]) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields are not specified, all the fields will be
            loaded. Otherwise, fields should be given in the format of
            `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
            or a list of strings. In both cases, the field should be
            specified as `@field`.
        - **reducers**: One or more reducers. Reducers may be found in the
            `aggregation` module.
        """
        fields = [fields] if isinstance(fields, str) else fields
        reducers = [reducers] if isinstance(reducers, Reducer) else reducers

        ret = ["GROUPBY", str(len(fields)), *fields]
        for reducer in reducers:
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result.

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
            the alias for the projection, and the value is the projection
            expression itself, for example `apply(square_root="sqrt(@foo)")`.
        """
        for alias, expr in kwexpr.items():
            ret = ["APPLY", expr]
            if alias is not None:
                ret += ["AS", alias]
            self._aggregateplan.extend(ret)

        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return


        Example of sorting the initial results:

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 10)\
            .group_by("@state", r.count())
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        AggregateRequest("@sale_amount:[10000, inf]")\
            .limit(0, 1000)\
            .group_by("@state", r.count())\
            .limit(0, 10)
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        _limit = Limit(offset, num)
        self._aggregateplan.extend(_limit.build_args())
        return self

    def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries.

        ### Parameters

        - **fields**: The fields by which to sort. This can be either a single
            field or a list of fields. If you wish to specify order, you can
            use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be
            used instead of `LIMIT` and is also faster.


        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        AggregateRequest()\
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
            .sort_by(Desc("@paid"), max=10)
        ```
        """
        if isinstance(fields, (str, SortDirection)):
            fields = [fields]

        fields_args = []
        for f in fields:
            if isinstance(f, SortDirection):
                fields_args += [f.field, f.DIRSTRING]
            else:
                fields_args += [f]

        ret = ["SORTBY", str(len(fields_args))]
        ret.extend(fields_args)
        max = kwargs.get("max", 0)
        if max > 0:
            ret += ["MAX", str(max)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: A single filter expression, or a list of
            expressions, referring to fields in the result set.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])

        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def scorer(self, scorer: str) -> "AggregateRequest":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def verbatim(self) -> "AggregateRequest":
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            args += ["MAXIDLE", str(max_idle * 1000)]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._scorer:
            ret.extend(["SCORER", self._scorer])

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret.append("LOAD")
            ret.append("*")
        elif self._loadfields:
            ret.append("LOAD")
            ret.append(str(len(self._loadfields)))
            ret.extend(self._loadfields)

        if self._dialect:
            ret.extend(["DIALECT", self._dialect])

        ret.extend(self._aggregateplan)

        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
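# Usage sketch: the request only assembles arguments; `Client.aggregate`
# sends them. The field and alias names here are illustrative:
#
#     from redis.commands.search import reducers
#
#     req = (
#         AggregateRequest("*")
#         .group_by("@state", reducers.count().alias("n"))
#         .sort_by(Desc("@n"), max=5)
#     )
#     req.build_args()
#     # -> ['*', 'SCORER', 'TFIDF', 'GROUPBY', '1', '@state', 'REDUCE',
#     #     'COUNT', '0', 'AS', 'n', 'SORTBY', '2', '@n', 'DESC', 'MAX', '5']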


class Cursor:
    def __init__(self, cid: int) -> None:
        self.cid = cid
        self.max_idle = 0
        self.count = 0

    def build_args(self):
        args = [str(self.cid)]
        if self.max_idle:
            args += ["MAXIDLE", str(self.max_idle)]
        if self.count:
            args += ["COUNT", str(self.count)]
        return args


class AggregateResult:
    def __init__(self, rows, cursor: Cursor, schema) -> None:
        self.rows = rows
        self.cursor = cursor
        self.schema = schema

    def __repr__(self) -> str:
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )
venv/lib/python3.11/site-packages/redis/commands/search/commands.py (1129 lines): diff suppressed because it is too large
@@ -0,0 +1,17 @@
class Document:
    """
    Represents a single document in a result set
    """

    def __init__(self, id, payload=None, **fields):
        self.id = id
        self.payload = payload
        for k, v in fields.items():
            setattr(self, k, v)

    def __repr__(self):
        return f"Document {self.__dict__}"

    def __getitem__(self, item):
        return getattr(self, item)
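# Usage sketch: fields arrive as keyword arguments and become attributes, so
# both attribute and item access work (values here are illustrative):
#
#     doc = Document("doc:1", title="hello", body="world")
#     doc.title    # -> 'hello'
#     doc["body"]  # -> 'world'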
venv/lib/python3.11/site-packages/redis/commands/search/field.py (210 lines)
@@ -0,0 +1,210 @@
from typing import List

from redis import DataError


class Field:
    """
    A class representing a field in a document.
    """

    NUMERIC = "NUMERIC"
    TEXT = "TEXT"
    WEIGHT = "WEIGHT"
    GEO = "GEO"
    TAG = "TAG"
    VECTOR = "VECTOR"
    SORTABLE = "SORTABLE"
    NOINDEX = "NOINDEX"
    AS = "AS"
    GEOSHAPE = "GEOSHAPE"
    INDEX_MISSING = "INDEXMISSING"
    INDEX_EMPTY = "INDEXEMPTY"

    def __init__(
        self,
        name: str,
        args: List[str] = None,
        sortable: bool = False,
        no_index: bool = False,
        index_missing: bool = False,
        index_empty: bool = False,
        as_name: str = None,
    ):
        """
        Create a new field object.

        Args:
            name: The name of the field.
            args: Extra arguments to pass to the field definition.
            sortable: If `True`, the field will be sortable.
            no_index: If `True`, the field will not be indexed.
            index_missing: If `True`, it will be possible to search for
                documents that have this field missing.
            index_empty: If `True`, it will be possible to search for
                documents that have this field empty.
            as_name: If provided, this alias will be used for the field.
        """
        if args is None:
            args = []
        self.name = name
        self.args = args
        self.args_suffix = list()
        self.as_name = as_name

        if sortable:
            self.args_suffix.append(Field.SORTABLE)
        if no_index:
            self.args_suffix.append(Field.NOINDEX)
        if index_missing:
            self.args_suffix.append(Field.INDEX_MISSING)
        if index_empty:
            self.args_suffix.append(Field.INDEX_EMPTY)

        if no_index and not sortable:
            raise ValueError("Non-Sortable non-Indexable fields are ignored")

    def append_arg(self, value):
        self.args.append(value)

    def redis_args(self):
        args = [self.name]
        if self.as_name:
            args += [self.AS, self.as_name]
        args += self.args
        args += self.args_suffix
        return args


class TextField(Field):
    """
    TextField is used to define a text field in a schema definition
    """

    NOSTEM = "NOSTEM"
    PHONETIC = "PHONETIC"

    def __init__(
        self,
        name: str,
        weight: float = 1.0,
        no_stem: bool = False,
        phonetic_matcher: str = None,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs)

        if no_stem:
            Field.append_arg(self, self.NOSTEM)
        if phonetic_matcher and phonetic_matcher in [
            "dm:en",
            "dm:fr",
            "dm:pt",
            "dm:es",
        ]:
            Field.append_arg(self, self.PHONETIC)
            Field.append_arg(self, phonetic_matcher)
        if withsuffixtrie:
            Field.append_arg(self, "WITHSUFFIXTRIE")


class NumericField(Field):
    """
    NumericField is used to define a numeric field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)


class GeoShapeField(Field):
    """
    GeoShapeField is used to enable within/contain indexing/searching
    """

    SPHERICAL = "SPHERICAL"
    FLAT = "FLAT"

    def __init__(self, name: str, coord_system=None, **kwargs):
        args = [Field.GEOSHAPE]
        if coord_system:
            args.append(coord_system)
        Field.__init__(self, name, args=args, **kwargs)


class GeoField(Field):
    """
    GeoField is used to define a geo-indexing field in a schema definition
    """

    def __init__(self, name: str, **kwargs):
        Field.__init__(self, name, args=[Field.GEO], **kwargs)


class TagField(Field):
    """
    TagField is a tag-indexing field with simpler compression and tokenization.
    See http://redisearch.io/Tags/
    """

    SEPARATOR = "SEPARATOR"
    CASESENSITIVE = "CASESENSITIVE"

    def __init__(
        self,
        name: str,
        separator: str = ",",
        case_sensitive: bool = False,
        withsuffixtrie: bool = False,
        **kwargs,
    ):
        args = [Field.TAG, self.SEPARATOR, separator]
        if case_sensitive:
            args.append(self.CASESENSITIVE)
        if withsuffixtrie:
            args.append("WITHSUFFIXTRIE")

        Field.__init__(self, name, args=args, **kwargs)


class VectorField(Field):
    """
    Allows vector similarity queries against the value in this attribute.
    See https://oss.redis.com/redisearch/Vectors/#vector_fields.
    """

    def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
        """
        Create Vector Field. Notice that Vector cannot have sortable or
        no_index tag, although it's also a Field.

        ``name`` is the name of the field.

        ``algorithm`` can be "FLAT" or "HNSW".

        ``attributes`` each algorithm can have specific attributes. Some of
        them are mandatory and some of them are optional. See
        https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm
        for more information.
        """
        sort = kwargs.get("sortable", False)
        noindex = kwargs.get("no_index", False)

        if sort or noindex:
            raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")

        if algorithm.upper() not in ["FLAT", "HNSW"]:
            raise DataError(
                "Realtime vector indexing supporting 2 Indexing Methods:"
                "'FLAT' and 'HNSW'."
            )

        attr_li = []
        for key, value in attributes.items():
            attr_li.extend([key, value])

        Field.__init__(
            self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs
        )
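# Schema sketch (field names are illustrative; the vector attributes follow
# the FLAT algorithm's required parameters):
#
#     schema = (
#         TextField("title", weight=5.0),
#         TagField("tags", separator="|"),
#         NumericField("price", sortable=True),
#         VectorField(
#             "embedding",
#             "FLAT",
#             {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "COSINE"},
#         ),
#     )
#     TextField("title", weight=5.0).redis_args()
#     # -> ['title', 'TEXT', 'WEIGHT', 5.0]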
@@ -0,0 +1,79 @@
from enum import Enum


class IndexType(Enum):
    """Enum of the currently supported index types."""

    HASH = 1
    JSON = 2


class IndexDefinition:
    """IndexDefinition is used to define an index definition for automatic
    indexing on Hash or Json update."""

    def __init__(
        self,
        prefix=None,
        filter=None,
        language_field=None,
        language=None,
        score_field=None,
        score=1.0,
        payload_field=None,
        index_type=None,
    ):
        if prefix is None:
            prefix = []
        self.args = []
        self._append_index_type(index_type)
        self._append_prefix(prefix)
        self._append_filter(filter)
        self._append_language(language_field, language)
        self._append_score(score_field, score)
        self._append_payload(payload_field)

    def _append_index_type(self, index_type):
        """Append `ON HASH` or `ON JSON` according to the enum."""
        if index_type is IndexType.HASH:
            self.args.extend(["ON", "HASH"])
        elif index_type is IndexType.JSON:
            self.args.extend(["ON", "JSON"])
        elif index_type is not None:
            raise RuntimeError(f"index_type must be one of {list(IndexType)}")

    def _append_prefix(self, prefix):
        """Append PREFIX."""
        if len(prefix) > 0:
            self.args.append("PREFIX")
            self.args.append(len(prefix))
            for p in prefix:
                self.args.append(p)

    def _append_filter(self, filter):
        """Append FILTER."""
        if filter is not None:
            self.args.append("FILTER")
            self.args.append(filter)

    def _append_language(self, language_field, language):
        """Append LANGUAGE_FIELD and LANGUAGE."""
        if language_field is not None:
            self.args.append("LANGUAGE_FIELD")
            self.args.append(language_field)
        if language is not None:
            self.args.append("LANGUAGE")
            self.args.append(language)

    def _append_score(self, score_field, score):
        """Append SCORE_FIELD and SCORE."""
        if score_field is not None:
            self.args.append("SCORE_FIELD")
            self.args.append(score_field)
        if score is not None:
            self.args.append("SCORE")
            self.args.append(score)

    def _append_payload(self, payload_field):
        """Append PAYLOAD_FIELD."""
        if payload_field is not None:
            self.args.append("PAYLOAD_FIELD")
            self.args.append(payload_field)
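# Usage sketch: the definition just accumulates FT.CREATE arguments, e.g.
#
#     IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH).args
#     # -> ['ON', 'HASH', 'PREFIX', 1, 'doc:', 'SCORE', 1.0]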
venv/lib/python3.11/site-packages/redis/commands/search/query.py (377 lines)
@@ -0,0 +1,377 @@
from typing import List, Optional, Union


class Query:
    """
    Query is used to build complex queries that have more parameters than just
    the query string. The query string is set in the constructor, and other
    options have setter functions.

    The setter functions return the query object, so they can be chained,
    i.e. `Query("foo").verbatim().filter(...)` etc.
    """

    def __init__(self, query_string: str) -> None:
        """
        Create a new query object.
        The query string is set in the constructor, and other options have
        setter functions.
        """

        self._query_string: str = query_string
        self._offset: int = 0
        self._num: int = 10
        self._no_content: bool = False
        self._no_stopwords: bool = False
        self._fields: Optional[List[str]] = None
        self._verbatim: bool = False
        self._with_payloads: bool = False
        self._with_scores: bool = False
        self._scorer: Optional[str] = None
        self._filters: List = list()
        self._ids: Optional[List[str]] = None
        self._slop: int = -1
        self._timeout: Optional[float] = None
        self._in_order: bool = False
        self._sortby: Optional[SortbyField] = None
        self._return_fields: List = []
        self._return_fields_decode_as: dict = {}
        self._summarize_fields: List = []
        self._highlight_fields: List = []
        self._language: Optional[str] = None
        self._expander: Optional[str] = None
        self._dialect: Optional[int] = None

    def query_string(self) -> str:
        """Return the query string of this query only."""
        return self._query_string

    def limit_ids(self, *ids) -> "Query":
        """Limit the results to a specific set of pre-known document
        ids of any length."""
        self._ids = ids
        return self

    def return_fields(self, *fields) -> "Query":
        """Add fields to return fields."""
        for field in fields:
            self.return_field(field)
        return self

    def return_field(
        self,
        field: str,
        as_field: Optional[str] = None,
        decode_field: Optional[bool] = True,
        encoding: Optional[str] = "utf8",
    ) -> "Query":
        """
        Add a field to the list of fields to return.

        - **field**: The field to include in query results
        - **as_field**: The alias for the field
        - **decode_field**: Whether to decode the field from bytes to string
        - **encoding**: The encoding to use when decoding the field
        """
        self._return_fields.append(field)
        self._return_fields_decode_as[field] = encoding if decode_field else None
        if as_field is not None:
            self._return_fields += ("AS", as_field)
        return self

    def _mk_field_list(self, fields: List[str]) -> List:
        if not fields:
            return []
        return [fields] if isinstance(fields, str) else list(fields)

    def summarize(
        self,
        fields: Optional[List] = None,
        context_len: Optional[int] = None,
        num_frags: Optional[int] = None,
        sep: Optional[str] = None,
    ) -> "Query":
        """
        Return an abridged format of the field, containing only the segments of
        the field which contain the matching term(s).

        If `fields` is specified, then only the mentioned fields are
        summarized; otherwise all results are summarized.

        Server side defaults are used for each option (except `fields`)
        if not specified.

        - **fields** List of fields to summarize. All fields are summarized
          if not specified
        - **context_len** Amount of context to include with each fragment
        - **num_frags** Number of fragments per document
        - **sep** Separator string to separate fragments
        """
        args = ["SUMMARIZE"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields

        if context_len is not None:
            args += ["LEN", str(context_len)]
        if num_frags is not None:
            args += ["FRAGS", str(num_frags)]
        if sep is not None:
            args += ["SEPARATOR", sep]

        self._summarize_fields = args
        return self

    def highlight(
        self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
    ) -> "Query":
        """
        Apply specified markup to matched term(s) within the returned field(s).

        - **fields** If specified then only those mentioned fields are
          highlighted, otherwise all fields are highlighted
        - **tags** A list of two strings to surround the match.
        """
        args = ["HIGHLIGHT"]
        fields = self._mk_field_list(fields)
        if fields:
            args += ["FIELDS", str(len(fields))] + fields
        if tags:
            args += ["TAGS"] + list(tags)

        self._highlight_fields = args
        return self

    def language(self, language: str) -> "Query":
        """
        Analyze the query as being in the specified language.

        :param language: The language (e.g. `chinese` or `english`)
        """
        self._language = language
        return self

    def slop(self, slop: int) -> "Query":
        """Allow a maximum of N intervening non-matched terms between
        phrase terms (0 means exact phrase).
        """
        self._slop = slop
        return self

    def timeout(self, timeout: float) -> "Query":
        """Overrides the timeout parameter of the module."""
        self._timeout = timeout
        return self

    def in_order(self) -> "Query":
        """
        Match only documents where the query terms appear in
        the same order in the document.
        i.e. for the query "hello world", we do not match "world hello"
        """
        self._in_order = True
        return self

    def scorer(self, scorer: str) -> "Query":
        """
        Use a different scoring function to evaluate document relevance.
        Default is `TFIDF`.

        :param scorer: The scoring function to use
                       (e.g. `TFIDF.DOCNORM` or `BM25`)
        """
        self._scorer = scorer
        return self

    def get_args(self) -> List[str]:
        """Format the redis arguments for this query and return them."""
        args = [self._query_string]
        args += self._get_args_tags()
        args += self._summarize_fields + self._highlight_fields
        args += ["LIMIT", self._offset, self._num]
        return args

    def _get_args_tags(self) -> List[str]:
        args = []
        if self._no_content:
            args.append("NOCONTENT")
        if self._fields:
            args.append("INFIELDS")
            args.append(len(self._fields))
            args += self._fields
        if self._verbatim:
            args.append("VERBATIM")
        if self._no_stopwords:
            args.append("NOSTOPWORDS")
        if self._filters:
            for flt in self._filters:
                if not isinstance(flt, Filter):
                    raise AttributeError("Did not receive a Filter object.")
                args += flt.args
        if self._with_payloads:
            args.append("WITHPAYLOADS")
        if self._scorer:
            args += ["SCORER", self._scorer]
        if self._with_scores:
            args.append("WITHSCORES")
        if self._ids:
            args.append("INKEYS")
            args.append(len(self._ids))
            args += self._ids
        if self._slop >= 0:
            args += ["SLOP", self._slop]
        if self._timeout is not None:
            args += ["TIMEOUT", self._timeout]
        if self._in_order:
            args.append("INORDER")
        if self._return_fields:
            args.append("RETURN")
            args.append(len(self._return_fields))
            args += self._return_fields
        if self._sortby:
            if not isinstance(self._sortby, SortbyField):
                raise AttributeError("Did not receive a SortByField.")
            args.append("SORTBY")
            args += self._sortby.args
        if self._language:
            args += ["LANGUAGE", self._language]
        if self._expander:
            args += ["EXPANDER", self._expander]
        if self._dialect:
            args += ["DIALECT", self._dialect]

        return args

    def paging(self, offset: int, num: int) -> "Query":
        """
        Set the paging for the query (defaults to 0..10).

        - **offset**: Paging offset for the results. Defaults to 0
        - **num**: How many results do we want
        """
        self._offset = offset
        self._num = num
        return self

    def verbatim(self) -> "Query":
        """Set the query to be verbatim, i.e. use no query expansion
        or stemming.
        """
        self._verbatim = True
        return self

    def no_content(self) -> "Query":
        """Set the query to only return ids and not the document content."""
        self._no_content = True
        return self

    def no_stopwords(self) -> "Query":
        """
        Prevent the query from being filtered for stopwords.
        Only useful in very big queries that you are certain contain
        no stopwords.
        """
        self._no_stopwords = True
        return self

    def with_payloads(self) -> "Query":
        """Ask the engine to return document payloads."""
        self._with_payloads = True
        return self

    def with_scores(self) -> "Query":
        """Ask the engine to return document search scores."""
        self._with_scores = True
        return self

    def limit_fields(self, *fields: List[str]) -> "Query":
        """
        Limit the search to specific TEXT fields only.

        - **fields**: A list of strings, case sensitive field names
          from the defined schema.
        """
        self._fields = fields
        return self

    def add_filter(self, flt: "Filter") -> "Query":
        """
        Add a numeric or geo filter to the query.
        **Currently only one of each filter is supported by the engine**

        - **flt**: A NumericFilter or GeoFilter object, used on a
          corresponding field
        """

        self._filters.append(flt)
        return self

    def sort_by(self, field: str, asc: bool = True) -> "Query":
        """
        Add a sortby field to the query.

        - **field** - the name of the field to sort by
        - **asc** - when `True`, sorting will be done in ascending order
        """
        self._sortby = SortbyField(field, asc)
        return self

    def expander(self, expander: str) -> "Query":
        """
        Add an expander field to the query.

        - **expander** - the name of the expander
        """
        self._expander = expander
        return self

    def dialect(self, dialect: int) -> "Query":
        """
        Add a dialect field to the query.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self


class Filter:
    def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
        self.args = [keyword, field] + list(args)


class NumericFilter(Filter):
    INF = "+inf"
    NEG_INF = "-inf"

    def __init__(
        self,
        field: str,
        minval: Union[int, str],
        maxval: Union[int, str],
        minExclusive: bool = False,
        maxExclusive: bool = False,
    ) -> None:
        args = [
            minval if not minExclusive else f"({minval}",
            maxval if not maxExclusive else f"({maxval}",
        ]

        Filter.__init__(self, "FILTER", field, *args)


class GeoFilter(Filter):
    METERS = "m"
    KILOMETERS = "km"
    FEET = "ft"
    MILES = "mi"

    def __init__(
        self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
    ) -> None:
        Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)


class SortbyField:
    def __init__(self, field: str, asc=True) -> None:
        self.args = [field, "ASC" if asc else "DESC"]
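# Usage sketch (field names are illustrative):
#
#     q = Query("hello world").verbatim().sort_by("price", asc=False).paging(0, 5)
#     q.get_args()
#     # -> ['hello world', 'VERBATIM', 'SORTBY', 'price', 'DESC', 'LIMIT', 0, 5]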
@@ -0,0 +1,317 @@
def tags(*t):
    """
    Indicate that the values should be matched to a tag field.

    ### Parameters

    - **t**: Tags to search for
    """
    if not t:
        raise ValueError("At least one tag must be specified")
    return TagValue(*t)


def between(a, b, inclusive_min=True, inclusive_max=True):
    """
    Indicate that value is a numeric range
    """
    return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max)


def equal(n):
    """
    Match a numeric value
    """
    return between(n, n)


def lt(n):
    """
    Match any value less than n
    """
    return between(None, n, inclusive_max=False)


def le(n):
    """
    Match any value less or equal to n
    """
    return between(None, n, inclusive_max=True)


def gt(n):
    """
    Match any value greater than n
    """
    return between(n, None, inclusive_min=False)


def ge(n):
    """
    Match any value greater or equal to n
    """
    return between(n, None, inclusive_min=True)


def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region
    """
    return GeoValue(lat, lon, radius, unit)


class Value:
    @property
    def combinable(self):
        """
        Whether this type of value may be combined with other values
        for the same field. This makes the filter potentially more efficient.
        """
        return False

    @staticmethod
    def make_value(v):
        """
        Convert an object to a value, if it is not a value already
        """
        if isinstance(v, Value):
            return v
        return ScalarValue(v)

    def to_string(self):
        raise NotImplementedError()

    def __str__(self):
        return self.to_string()


class RangeValue(Value):
    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        if a is None:
            a = "-inf"
        if b is None:
            b = "inf"
        self.range = [str(a), str(b)]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        return "[{1}{0[0]} {2}{0[1]}]".format(
            self.range,
            "(" if not self.inclusive_min else "",
            "(" if not self.inclusive_max else "",
        )


class ScalarValue(Value):
    combinable = True

    def __init__(self, v):
        self.v = str(v)

    def to_string(self):
        return self.v


class TagValue(Value):
    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        return "{" + " | ".join(str(t) for t in self.tags) + "}"


class GeoValue(Value):
    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        return f"[{self.lon} {self.lat} {self.radius} {self.unit}]"


class Node:
    def __init__(self, *children, **kwparams):
        """
        Create a node.

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
            `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
            type.

            The semantics of multiple conditions are dependent on the type of
            query. For an `intersection` node, this amounts to a logical AND,
            for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
            and the value should be a field value. This can be one of the
            following:

            - Simple string (for text field matches)
            - value returned by one of the helper functions
            - list of either a string or a value


        ### Examples

        Field `num` should be between 1 and 10:

        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`:

        ```
        union(name=("bob", "john"))
        ```

        Don't select countries in Israel, Japan, or US:

        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """

        self.params = []

        kvparams = {}
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (str, int, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                curvals.extend(Value.make_value(subv) for subv in v)

        self.params += [Node.to_node(p) for p in children]

        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        if len(vals) == 1:
            return [BaseNode(f"@{key}:{vals[0].to_string()}")]
        if not vals[0].combinable:
            return [BaseNode(f"@{key}:{v.to_string()}") for v in vals]
        s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})")
        return [s]

    @classmethod
    def to_node(cls, obj):  # noqa
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}"

    def _should_use_paren(self, optval):
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()


class BaseNode(Node):
    def __init__(self, s):
        super().__init__()
        self.s = str(s)

    def to_string(self, with_parens=None):
        return self.s


class IntersectNode(Node):
    """
    Create an intersection node. All children need to be satisfied in order
    for this node to evaluate as true.
    """

    JOINSTR = " "


class UnionNode(Node):
    """
    Create a union node. Any of the children need to be satisfied in order
    for this node to evaluate as true.
    """

    JOINSTR = "|"


class DisjunctNode(IntersectNode):
    """
    Create a disjunct node. In order for this node to be true, all of its
    children must evaluate to false.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(-" + ret + ")"
        else:
            return "-" + ret


class DistjunctUnion(DisjunctNode):
    """
    This node is true if *all* of its children are false. This is equivalent
    to

    ```
    disjunct(union(...))
    ```
    """

    JOINSTR = "|"


class OptionalNode(IntersectNode):
    """
    Create an optional node. If this node evaluates to true, then the
    document will be rated higher in score/rank.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super().to_string(with_parens=False)
        if with_parens:
            return "(~" + ret + ")"
        else:
            return "~" + ret


def intersect(*args, **kwargs):
    return IntersectNode(*args, **kwargs)


def union(*args, **kwargs):
    return UnionNode(*args, **kwargs)


def disjunct(*args, **kwargs):
    return DisjunctNode(*args, **kwargs)


def disjunct_union(*args, **kwargs):
    return DistjunctUnion(*args, **kwargs)


def querystring(*args, **kwargs):
    return intersect(*args, **kwargs).to_string()
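# Usage sketch (field names are illustrative):
#
#     querystring(name="james", age=between(18, 30))
#     # -> '(@name:james @age:[18 30])'
#     str(union(name=("bob", "john")))
#     # -> '@name:(bob|john)'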
@@ -0,0 +1,182 @@
from typing import Union

from .aggregation import Asc, Desc, Reducer, SortDirection


class FieldOnlyReducer(Reducer):
    """See https://redis.io/docs/interact/search-and-query/search/aggregations/"""

    def __init__(self, field: str) -> None:
        super().__init__(field)
        self._field = field


class count(Reducer):
    """
    Counts the number of results in the group
    """

    NAME = "COUNT"

    def __init__(self) -> None:
        super().__init__()


class sum(FieldOnlyReducer):
    """
    Calculates the sum of all the values in the given fields within the group
    """

    NAME = "SUM"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class min(FieldOnlyReducer):
    """
    Calculates the smallest value in the given field within the group
    """

    NAME = "MIN"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class max(FieldOnlyReducer):
    """
    Calculates the largest value in the given field within the group
    """

    NAME = "MAX"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class avg(FieldOnlyReducer):
    """
    Calculates the mean value in the given field within the group
    """

    NAME = "AVG"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class tolist(FieldOnlyReducer):
    """
    Returns all the matched properties in a list
    """

    NAME = "TOLIST"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinct(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field
    """

    NAME = "COUNT_DISTINCT"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class count_distinctish(FieldOnlyReducer):
    """
    Calculate the number of distinct values contained in all the results in
    the group for the given field. This uses a faster algorithm than
    `count_distinct` but is less accurate
    """

    NAME = "COUNT_DISTINCTISH"


class quantile(Reducer):
    """
    Return the value for the nth percentile within the range of values for
    the field within the group.
    """

    NAME = "QUANTILE"

    def __init__(self, field: str, pct: float) -> None:
        super().__init__(field, str(pct))
        self._field = field


class stddev(FieldOnlyReducer):
    """
    Return the standard deviation for the values within the group
    """

    NAME = "STDDEV"

    def __init__(self, field: str) -> None:
        super().__init__(field)


class first_value(Reducer):
    """
    Selects the first value within the group according to sorting parameters
    """

    NAME = "FIRST_VALUE"

    def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None:
        """
        Selects the first value of the given field within the group.

        ### Parameters

        - **field**: Source field used for the value
        - **byfields**: How to sort the results. This can be either the
            *class* of `aggregation.Asc` or `aggregation.Desc` in which
            case the field `field` is also used as the sort input.

            `byfields` can also be one or more *instances* of `Asc` or `Desc`
            indicating the sort order for these fields
        """

        fieldstrs = []
        if (
            len(byfields) == 1
            and isinstance(byfields[0], type)
            and issubclass(byfields[0], SortDirection)
        ):
            byfields = [byfields[0](field)]

        for f in byfields:
            fieldstrs += [f.field, f.DIRSTRING]

        args = [field]
        if fieldstrs:
            args += ["BY"] + fieldstrs
        super().__init__(*args)
        self._field = field


class random_sample(Reducer):
    """
    Returns a random sample of items from the dataset, from the given property
    """

    NAME = "RANDOM_SAMPLE"

    def __init__(self, field: str, size: int) -> None:
        """
        ### Parameters

        - **field**: Field to sample from
        - **size**: Return this many items (can be less)
        """
        args = [field, str(size)]
        super().__init__(*args)
        self._field = field
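# Usage sketch: reducers plug into `AggregateRequest.group_by` (field names
# are illustrative):
#
#     from redis.commands.search.aggregation import AggregateRequest
#
#     req = AggregateRequest("*").group_by(
#         "@brand", count(), avg("@price").alias("avg_price")
#     )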
@@ -0,0 +1,87 @@
from typing import Optional

from ._util import to_string
from .document import Document


class Result:
    """
    Represents the result of a search query, and has an array of Document
    objects
    """

    def __init__(
        self,
        res,
        hascontent,
        duration=0,
        has_payload=False,
        with_scores=False,
        field_encodings: Optional[dict] = None,
    ):
        """
        - duration: the execution time of the query
        - has_payload: whether the query has payloads
        - with_scores: whether the query has scores
        - field_encodings: a dictionary of field encodings if any is provided
        """

        self.total = res[0]
        self.duration = duration
        self.docs = []

        # Each document occupies `step` consecutive entries in the flat
        # reply, depending on which optional sections were requested.
        step = 1
        if hascontent:
            step = step + 1
        if has_payload:
            step = step + 1
        if with_scores:
            step = step + 1

        offset = 2 if with_scores else 1

        for i in range(1, len(res), step):
            id = to_string(res[i])
            payload = to_string(res[i + offset]) if has_payload else None
            # fields_offset = 2 if has_payload else 1
            fields_offset = offset + 1 if has_payload else offset
            score = float(res[i + 1]) if with_scores else None

            fields = {}
            if hascontent and res[i + fields_offset] is not None:
                # The fields section is a flat [key, value, key, value, ...]
                # array.
                keys = map(to_string, res[i + fields_offset][::2])
                values = res[i + fields_offset][1::2]

                for key, value in zip(keys, values):
                    if field_encodings is None or key not in field_encodings:
                        fields[key] = to_string(value)
                        continue

                    encoding = field_encodings[key]

                    # If the encoding is None, we don't need to decode the value
                    if encoding is None:
                        fields[key] = value
                    else:
                        fields[key] = to_string(value, encoding=encoding)

            try:
                del fields["id"]
            except KeyError:
                pass

            try:
                fields["json"] = fields["$"]
                del fields["$"]
            except KeyError:
                pass

            doc = (
                Document(id, score=score, payload=payload, **fields)
                if with_scores
                else Document(id, payload=payload, **fields)
            )
            self.docs.append(doc)

    def __repr__(self) -> str:
        return f"Result{{{self.total} total, docs: {self.docs}}}"
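# Parsing sketch: a raw RESP2 FT.SEARCH reply is a flat array of
# [total, id, fields, id, fields, ...] (values here are illustrative):
#
#     raw = [1, b"doc:1", [b"title", b"hello"]]
#     res = Result(raw, hascontent=True)
#     res.total          # -> 1
#     res.docs[0].title  # -> 'hello'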
@@ -0,0 +1,55 @@
from typing import Optional

from ._util import to_string


class Suggestion:
    """
    Represents a single suggestion being sent or returned from the
    autocomplete server
    """

    def __init__(
        self, string: str, score: float = 1.0, payload: Optional[str] = None
    ) -> None:
        self.string = to_string(string)
        self.payload = to_string(payload)
        self.score = score

    def __repr__(self) -> str:
        return self.string


class SuggestionParser:
    """
    Internal class used to parse results from the `SUGGET` command.
    This needs to consume either 1, 2, or 3 values at a time from
    the return value depending on what objects were requested
    """

    def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None:
        self.with_scores = with_scores
        self.with_payloads = with_payloads

        if with_scores and with_payloads:
            self.sugsize = 3
            self._scoreidx = 1
            self._payloadidx = 2
        elif with_scores:
            self.sugsize = 2
            self._scoreidx = 1
        elif with_payloads:
            self.sugsize = 2
            self._payloadidx = 1
        else:
            self.sugsize = 1
            self._scoreidx = -1

        self._sugs = ret

    def __iter__(self):
        for i in range(0, len(self._sugs), self.sugsize):
            ss = self._sugs[i]
            score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0
            payload = self._sugs[i + self._payloadidx] if self.with_payloads else None
            yield Suggestion(ss, score, payload)
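# Parsing sketch: SUGGET with scores returns a flat [string, score, ...]
# array, which the parser regroups (values here are illustrative):
#
#     raw = [b"hello", b"0.5", b"help", b"0.25"]
#     [s.string for s in SuggestionParser(True, False, raw)]
#     # -> ['hello', 'help']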