fix
257 venv/lib/python3.11/site-packages/pymongo/asynchronous/aggregation.py (Normal file)
@@ -0,0 +1,257 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Perform aggregation operations on a collection or database."""
from __future__ import annotations

from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref

if TYPE_CHECKING:
    from pymongo.asynchronous.client_session import AsyncClientSession
    from pymongo.asynchronous.collection import AsyncCollection
    from pymongo.asynchronous.command_cursor import AsyncCommandCursor
    from pymongo.asynchronous.database import AsyncDatabase
    from pymongo.asynchronous.pool import AsyncConnection
    from pymongo.asynchronous.server import Server
    from pymongo.read_preferences import _ServerMode
    from pymongo.typings import _DocumentType, _Pipeline

_IS_SYNC = False


class _AggregationCommand:
    """The internal abstract base class for aggregation cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.asynchronous.collection.AsyncCollection.aggregate`, or
    :meth:`pymongo.asynchronous.database.AsyncDatabase.aggregate` instead.
    """

    def __init__(
        self,
        target: Union[AsyncDatabase, AsyncCollection],
        cursor_class: type[AsyncCommandCursor],
        pipeline: _Pipeline,
        options: MutableMapping[str, Any],
        explicit_session: bool,
        let: Optional[Mapping[str, Any]] = None,
        user_fields: Optional[MutableMapping[str, Any]] = None,
        result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
        comment: Any = None,
    ) -> None:
        if "explain" in options:
            raise ConfigurationError(
                "The explain option is not supported. Use AsyncDatabase.command instead."
            )

        self._target = target

        pipeline = common.validate_list("pipeline", pipeline)
        self._pipeline = pipeline
        self._performs_write = False
        if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
            self._performs_write = True

        common.validate_is_mapping("options", options)
        if let is not None:
            common.validate_is_mapping("let", let)
            options["let"] = let
        if comment is not None:
            options["comment"] = comment

        self._options = options

        # This is the batchSize that will be used for setting the initial
        # batchSize for the cursor, as well as the subsequent getMores.
        self._batch_size = common.validate_non_negative_integer_or_none(
            "batchSize", self._options.pop("batchSize", None)
        )

        # If the cursor option is already specified, avoid overriding it.
        self._options.setdefault("cursor", {})
        # If the pipeline performs a write, we ignore the initial batchSize
        # since the server doesn't return results in this case.
        if self._batch_size is not None and not self._performs_write:
            self._options["cursor"]["batchSize"] = self._batch_size

        self._cursor_class = cursor_class
        self._explicit_session = explicit_session
        self._user_fields = user_fields
        self._result_processor = result_processor

        self._collation = validate_collation_or_none(options.pop("collation", None))

        self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
        self._write_preference: Optional[_AggWritePref] = None

    @property
    def _aggregation_target(self) -> Union[str, int]:
        """The argument to pass to the aggregate command."""
        raise NotImplementedError

    @property
    def _cursor_namespace(self) -> str:
        """The namespace in which the aggregate command is run."""
        raise NotImplementedError

    def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        raise NotImplementedError

    @property
    def _database(self) -> AsyncDatabase:
        """The database against which the aggregation command is run."""
        raise NotImplementedError

    def get_read_preference(
        self, session: Optional[AsyncClientSession]
    ) -> Union[_AggWritePref, _ServerMode]:
        if self._write_preference:
            return self._write_preference
        pref = self._target._read_preference_for(session)
        if self._performs_write and pref != ReadPreference.PRIMARY:
            self._write_preference = pref = _AggWritePref(pref)  # type: ignore[assignment]
        return pref

    async def get_cursor(
        self,
        session: Optional[AsyncClientSession],
        server: Server,
        conn: AsyncConnection,
        read_preference: _ServerMode,
    ) -> AsyncCommandCursor[_DocumentType]:
        # Serialize command.
        cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
        cmd.update(self._options)

        # Apply this target's read concern if:
        # readConcern has not been specified as a kwarg and either
        # - server version is >= 4.2 or
        # - server version is >= 3.2 and pipeline doesn't use $out
        if ("readConcern" not in cmd) and (
            not self._performs_write or (conn.max_wire_version >= 8)
        ):
            read_concern = self._target.read_concern
        else:
            read_concern = None

        # Apply this target's write concern if:
        # writeConcern has not been specified as a kwarg and pipeline doesn't
        # perform a write operation
        if "writeConcern" not in cmd and self._performs_write:
            write_concern = self._target._write_concern_for(session)
        else:
            write_concern = None

        # Run command.
        result = await conn.command(
            self._database.name,
            cmd,
            read_preference,
            self._target.codec_options,
            parse_write_concern_error=True,
            read_concern=read_concern,
            write_concern=write_concern,
            collation=self._collation,
            session=session,
            client=self._database.client,
            user_fields=self._user_fields,
        )

        if self._result_processor:
            self._result_processor(result, conn)

        # Extract cursor from result or mock/fake one if necessary.
        if "cursor" in result:
            cursor = result["cursor"]
        else:
            # Unacknowledged $out/$merge write. Fake a cursor.
            cursor = {
                "id": 0,
                "firstBatch": result.get("result", []),
                "ns": self._cursor_namespace,
            }

        # Create and return cursor instance.
        cmd_cursor = self._cursor_class(
            self._cursor_collection(cursor),
            cursor,
            conn.address,
            batch_size=self._batch_size or 0,
            max_await_time_ms=self._max_await_time_ms,
            session=session,
            explicit_session=self._explicit_session,
            comment=self._options.get("comment"),
        )
        await cmd_cursor._maybe_pin_connection(conn)
        return cmd_cursor


class _CollectionAggregationCommand(_AggregationCommand):
    _target: AsyncCollection

    @property
    def _aggregation_target(self) -> str:
        return self._target.name

    @property
    def _cursor_namespace(self) -> str:
        return self._target.full_name

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        return self._target

    @property
    def _database(self) -> AsyncDatabase:
        return self._target.database


class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # For raw-batches, we set the initial batchSize for the cursor to 0.
        if not self._performs_write:
            self._options["cursor"]["batchSize"] = 0


class _DatabaseAggregationCommand(_AggregationCommand):
    _target: AsyncDatabase

    @property
    def _aggregation_target(self) -> int:
        return 1

    @property
    def _cursor_namespace(self) -> str:
        return f"{self._target.name}.$cmd.aggregate"

    @property
    def _database(self) -> AsyncDatabase:
        return self._target

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
        """The AsyncCollection used for the aggregate command cursor."""
        # AsyncCollection level aggregate may not always return the "ns" field
        # according to our MockupDB tests. Let's handle that case for db level
        # aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
        _, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
        return self._database[collname]
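For orientation, the command class above is only reached through the public aggregate helpers. A minimal usage sketch follows (hedged: the URI and names are placeholders, and it assumes a PyMongo release that exports AsyncMongoClient from the top-level package):

import asyncio

from pymongo import AsyncMongoClient


async def main() -> None:
    client = AsyncMongoClient("mongodb://localhost:27017")
    coll = client.test_db.orders
    pipeline = [
        {"$match": {"status": "shipped"}},
        {"$group": {"_id": "$customer_id", "total": {"$sum": "$amount"}}},
    ]
    # batchSize is popped by __init__ above into options["cursor"]["batchSize"].
    async for doc in await coll.aggregate(pipeline, batchSize=100, comment="monthly rollup"):
        print(doc)
    await client.close()


asyncio.run(main())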
457 venv/lib/python3.11/site-packages/pymongo/asynchronous/auth.py (Normal file)
@@ -0,0 +1,457 @@
|
||||
# Copyright 2013-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import socket
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote
|
||||
|
||||
from bson.binary import Binary
|
||||
from pymongo.asynchronous.auth_aws import _authenticate_aws
|
||||
from pymongo.asynchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
)
|
||||
from pymongo.auth_shared import (
|
||||
MongoCredential,
|
||||
_authenticate_scram_start,
|
||||
_parse_scram_response,
|
||||
_xor,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.saslprep import saslprep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.hello import Hello
|
||||
|
||||
HAVE_KERBEROS = True
|
||||
_USE_PRINCIPAL = False
|
||||
try:
|
||||
import winkerberos as kerberos # type:ignore[import]
|
||||
|
||||
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
|
||||
_USE_PRINCIPAL = True
|
||||
except ImportError:
|
||||
try:
|
||||
import kerberos # type:ignore[import]
|
||||
except ImportError:
|
||||
HAVE_KERBEROS = False
|
||||
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def _authenticate_scram(
|
||||
credentials: MongoCredential, conn: AsyncConnection, mechanism: str
|
||||
) -> None:
|
||||
"""Authenticate using SCRAM."""
|
||||
username = credentials.username
|
||||
if mechanism == "SCRAM-SHA-256":
|
||||
digest = "sha256"
|
||||
digestmod = hashlib.sha256
|
||||
data = saslprep(credentials.password).encode("utf-8")
|
||||
else:
|
||||
digest = "sha1"
|
||||
digestmod = hashlib.sha1
|
||||
data = _password_digest(username, credentials.password).encode("utf-8")
|
||||
source = credentials.source
|
||||
cache = credentials.cache
|
||||
|
||||
# Make local
|
||||
_hmac = hmac.HMAC
|
||||
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
assert isinstance(ctx, _ScramContext)
|
||||
assert ctx.scram_data is not None
|
||||
nonce, first_bare = ctx.scram_data
|
||||
res = ctx.speculative_authenticate
|
||||
else:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
|
||||
res = await conn.command(source, cmd)
|
||||
|
||||
assert res is not None
|
||||
server_first = res["payload"]
|
||||
parsed = _parse_scram_response(server_first)
|
||||
iterations = int(parsed[b"i"])
|
||||
if iterations < 4096:
|
||||
raise OperationFailure("Server returned an invalid iteration count.")
|
||||
salt = parsed[b"s"]
|
||||
rnonce = parsed[b"r"]
|
||||
if not rnonce.startswith(nonce):
|
||||
raise OperationFailure("Server returned an invalid nonce.")
|
||||
|
||||
without_proof = b"c=biws,r=" + rnonce
|
||||
if cache.data:
|
||||
client_key, server_key, csalt, citerations = cache.data
|
||||
else:
|
||||
client_key, server_key, csalt, citerations = None, None, None, None
|
||||
|
||||
# Salt and / or iterations could change for a number of different
|
||||
# reasons. Either changing invalidates the cache.
|
||||
if not client_key or salt != csalt or iterations != citerations:
|
||||
salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
|
||||
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
|
||||
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
|
||||
cache.data = (client_key, server_key, salt, iterations)
|
||||
stored_key = digestmod(client_key).digest()
|
||||
auth_msg = b",".join((first_bare, server_first, without_proof))
|
||||
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
|
||||
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
|
||||
client_final = b",".join((without_proof, client_proof))
|
||||
|
||||
server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())
|
||||
|
||||
cmd = {
|
||||
"saslContinue": 1,
|
||||
"conversationId": res["conversationId"],
|
||||
"payload": Binary(client_final),
|
||||
}
|
||||
res = await conn.command(source, cmd)
|
||||
|
||||
parsed = _parse_scram_response(res["payload"])
|
||||
if not hmac.compare_digest(parsed[b"v"], server_sig):
|
||||
raise OperationFailure("Server returned an invalid signature.")
|
||||
|
||||
# A third empty challenge may be required if the server does not support
|
||||
# skipEmptyExchange: SERVER-44857.
|
||||
if not res["done"]:
|
||||
cmd = {
|
||||
"saslContinue": 1,
|
||||
"conversationId": res["conversationId"],
|
||||
"payload": Binary(b""),
|
||||
}
|
||||
res = await conn.command(source, cmd)
|
||||
if not res["done"]:
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
|
||||
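# Illustrative sketch (not pymongo code): the key derivation performed by
# _authenticate_scram above, shown in isolation for SCRAM-SHA-256
# (RFC 5802 / RFC 7677). All inputs below are dummy values.
def _scram_sha256_proof_example() -> None:
    from base64 import standard_b64decode

    password = "pencil".encode("utf-8")  # already saslprep'd
    salt = standard_b64decode(b"W22ZaJ0SNY7soEsUEjb6gQ==")  # server-first "s" field
    iterations = 4096  # server-first "i" field; must be >= 4096

    salted_pass = hashlib.pbkdf2_hmac("sha256", password, salt, iterations)
    client_key = hmac.HMAC(salted_pass, b"Client Key", hashlib.sha256).digest()
    server_key = hmac.HMAC(salted_pass, b"Server Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    # ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage); the server reply
    # is HMAC(ServerKey, AuthMessage) and is checked with hmac.compare_digest,
    # exactly as done in _authenticate_scram above.
    assert len(client_key) == len(server_key) == len(stored_key) == 32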
def _password_digest(username: str, password: str) -> str:
|
||||
"""Get a password digest to use for authentication."""
|
||||
if not isinstance(password, str):
|
||||
raise TypeError("password must be an instance of str")
|
||||
if len(password) == 0:
|
||||
raise ValueError("password can't be empty")
|
||||
if not isinstance(username, str):
|
||||
raise TypeError("username must be an instance of str")
|
||||
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{username}:mongo:{password}"
|
||||
md5hash.update(data.encode("utf-8"))
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _auth_key(nonce: str, username: str, password: str) -> str:
|
||||
"""Get an auth key to use for authentication."""
|
||||
digest = _password_digest(username, password)
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{nonce}{username}{digest}"
|
||||
md5hash.update(data.encode("utf-8"))
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname: str) -> str:
|
||||
"""Canonicalize hostname following MIT-krb5 behavior."""
|
||||
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
|
||||
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
|
||||
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
|
||||
)[0]
|
||||
|
||||
try:
|
||||
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
|
||||
except socket.gaierror:
|
||||
return canonname.lower()
|
||||
|
||||
return name[0].lower()
|
||||
|
||||
|
||||
async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using GSSAPI."""
|
||||
if not HAVE_KERBEROS:
|
||||
raise ConfigurationError(
|
||||
'The "kerberos" module must be installed to use GSSAPI authentication.'
|
||||
)
|
||||
|
||||
try:
|
||||
username = credentials.username
|
||||
password = credentials.password
|
||||
props = credentials.mechanism_properties
|
||||
# Starting here and continuing through the while loop below - establish
|
||||
# the security context. See RFC 4752, Section 3.1, first paragraph.
|
||||
host = conn.address[0]
|
||||
if props.canonicalize_host_name:
|
||||
host = _canonicalize_hostname(host)
|
||||
service = props.service_name + "@" + host
|
||||
if props.service_realm is not None:
|
||||
service = service + "@" + props.service_realm
|
||||
|
||||
if password is not None:
|
||||
if _USE_PRINCIPAL:
|
||||
# Note that, though we use unquote_plus for unquoting URI
|
||||
# options, we use quote here. Microsoft's UrlUnescape (used
|
||||
# by WinKerberos) doesn't support +.
|
||||
principal = ":".join((quote(username), quote(password)))
|
||||
result, ctx = kerberos.authGSSClientInit(
|
||||
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
|
||||
)
|
||||
else:
|
||||
if "@" in username:
|
||||
user, domain = username.split("@", 1)
|
||||
else:
|
||||
user, domain = username, None
|
||||
result, ctx = kerberos.authGSSClientInit(
|
||||
service,
|
||||
gssflags=kerberos.GSS_C_MUTUAL_FLAG,
|
||||
user=user,
|
||||
domain=domain,
|
||||
password=password,
|
||||
)
|
||||
else:
|
||||
result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
|
||||
|
||||
if result != kerberos.AUTH_GSS_COMPLETE:
|
||||
raise OperationFailure("Kerberos context failed to initialize.")
|
||||
|
||||
try:
|
||||
# pykerberos uses a weird mix of exceptions and return values
|
||||
# to indicate errors.
|
||||
# 0 == continue, 1 == complete, -1 == error
|
||||
# Only authGSSClientStep can return 0.
|
||||
if kerberos.authGSSClientStep(ctx, "") != 0:
|
||||
raise OperationFailure("Unknown kerberos failure in step function.")
|
||||
|
||||
# Start a SASL conversation with mongod/s
|
||||
# Note: pykerberos deals with base64 encoded byte strings.
|
||||
# Since mongo accepts base64 strings as the payload we don't
|
||||
# have to use bson.binary.Binary.
|
||||
payload = kerberos.authGSSClientResponse(ctx)
|
||||
cmd = {
|
||||
"saslStart": 1,
|
||||
"mechanism": "GSSAPI",
|
||||
"payload": payload,
|
||||
"autoAuthorize": 1,
|
||||
}
|
||||
response = await conn.command("$external", cmd)
|
||||
|
||||
# Limit how many times we loop to catch protocol / library issues
|
||||
for _ in range(10):
|
||||
result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
|
||||
if result == -1:
|
||||
raise OperationFailure("Unknown kerberos failure in step function.")
|
||||
|
||||
payload = kerberos.authGSSClientResponse(ctx) or ""
|
||||
|
||||
cmd = {
|
||||
"saslContinue": 1,
|
||||
"conversationId": response["conversationId"],
|
||||
"payload": payload,
|
||||
}
|
||||
response = await conn.command("$external", cmd)
|
||||
|
||||
if result == kerberos.AUTH_GSS_COMPLETE:
|
||||
break
|
||||
else:
|
||||
raise OperationFailure("Kerberos authentication failed to complete.")
|
||||
|
||||
# Once the security context is established actually authenticate.
|
||||
# See RFC 4752, Section 3.1, last two paragraphs.
|
||||
if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
|
||||
raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")
|
||||
|
||||
if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
|
||||
raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")
|
||||
|
||||
payload = kerberos.authGSSClientResponse(ctx)
|
||||
cmd = {
|
||||
"saslContinue": 1,
|
||||
"conversationId": response["conversationId"],
|
||||
"payload": payload,
|
||||
}
|
||||
await conn.command("$external", cmd)
|
||||
|
||||
finally:
|
||||
kerberos.authGSSClientClean(ctx)
|
||||
|
||||
except kerberos.KrbError as exc:
|
||||
raise OperationFailure(str(exc)) from None
|
||||
|
||||
|
||||
async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using SASL PLAIN (RFC 4616)"""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
password = credentials.password
|
||||
payload = (f"\x00{username}\x00{password}").encode()
|
||||
cmd = {
|
||||
"saslStart": 1,
|
||||
"mechanism": "PLAIN",
|
||||
"payload": Binary(payload),
|
||||
"autoAuthorize": 1,
|
||||
}
|
||||
await conn.command(source, cmd)
|
||||
|
||||
|
||||
async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using MONGODB-X509."""
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
# MONGODB-X509 is done after the speculative auth step.
|
||||
return
|
||||
|
||||
cmd = _X509Context(credentials, conn.address).speculate_command()
|
||||
await conn.command("$external", cmd)
|
||||
|
||||
|
||||
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
"""Authenticate using MONGODB-CR."""
|
||||
source = credentials.source
|
||||
username = credentials.username
|
||||
password = credentials.password
|
||||
# Get a nonce
|
||||
response = await conn.command(source, {"getnonce": 1})
|
||||
nonce = response["nonce"]
|
||||
key = _auth_key(nonce, username, password)
|
||||
|
||||
# Actually authenticate
|
||||
query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key}
|
||||
await conn.command(source, query)
|
||||
|
||||
|
||||
async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None:
|
||||
if conn.max_wire_version >= 7:
|
||||
if conn.negotiated_mechs:
|
||||
mechs = conn.negotiated_mechs
|
||||
else:
|
||||
source = credentials.source
|
||||
cmd = conn.hello_cmd()
|
||||
cmd["saslSupportedMechs"] = source + "." + credentials.username
|
||||
mechs = (await conn.command(source, cmd, publish_events=False)).get(
|
||||
"saslSupportedMechs", []
|
||||
)
|
||||
if "SCRAM-SHA-256" in mechs:
|
||||
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
|
||||
else:
|
||||
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
|
||||
else:
|
||||
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
|
||||
|
||||
|
||||
_AUTH_MAP: Mapping[str, Callable[..., Coroutine[Any, Any, None]]] = {
|
||||
"GSSAPI": _authenticate_gssapi,
|
||||
"MONGODB-CR": _authenticate_mongo_cr,
|
||||
"MONGODB-X509": _authenticate_x509,
|
||||
"MONGODB-AWS": _authenticate_aws,
|
||||
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
|
||||
"PLAIN": _authenticate_plain,
|
||||
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
|
||||
"DEFAULT": _authenticate_default,
|
||||
}
|
||||
|
||||
|
||||
class _AuthContext:
|
||||
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
|
||||
self.credentials = credentials
|
||||
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
|
||||
self.address = address
|
||||
|
||||
@staticmethod
|
||||
def from_credentials(
|
||||
creds: MongoCredential, address: tuple[str, int]
|
||||
) -> Optional[_AuthContext]:
|
||||
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
|
||||
if spec_cls:
|
||||
return cast(_AuthContext, spec_cls(creds, address))
|
||||
return None
|
||||
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
|
||||
self.speculative_authenticate = hello.speculative_authenticate
|
||||
|
||||
def speculate_succeeded(self) -> bool:
|
||||
return bool(self.speculative_authenticate)
|
||||
|
||||
|
||||
class _ScramContext(_AuthContext):
|
||||
def __init__(
|
||||
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
|
||||
) -> None:
|
||||
super().__init__(credentials, address)
|
||||
self.scram_data: Optional[tuple[bytes, bytes]] = None
|
||||
self.mechanism = mechanism
|
||||
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
|
||||
# The 'db' field is included only on the speculative command.
|
||||
cmd["db"] = self.credentials.source
|
||||
# Save for later use.
|
||||
self.scram_data = (nonce, first_bare)
|
||||
return cmd
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
|
||||
def speculate_command(self) -> MutableMapping[str, Any]:
|
||||
cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
|
||||
if self.credentials.username is not None:
|
||||
cmd["user"] = self.credentials.username
|
||||
return cmd
|
||||
|
||||
|
||||
class _OIDCContext(_AuthContext):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
authenticator = _get_authenticator(self.credentials, self.address)
|
||||
cmd = authenticator.get_spec_auth_cmd()
|
||||
if cmd is None:
|
||||
return None
|
||||
cmd["db"] = self.credentials.source
|
||||
return cmd
|
||||
|
||||
|
||||
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
|
||||
"MONGODB-X509": _X509Context,
|
||||
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
|
||||
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
"MONGODB-OIDC": _OIDCContext,
|
||||
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
|
||||
}
|
||||
|
||||
|
||||
async def authenticate(
|
||||
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False
|
||||
) -> None:
|
||||
"""Authenticate connection."""
|
||||
mechanism = credentials.mechanism
|
||||
auth_func = _AUTH_MAP[mechanism]
|
||||
if mechanism == "MONGODB-OIDC":
|
||||
await _authenticate_oidc(credentials, conn, reauthenticate)
|
||||
else:
|
||||
await auth_func(credentials, conn)
|
||||
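A condensed sketch of how the speculative-auth pieces above fit together during the connection handshake (not pymongo code; the real wiring lives in the pool/hello logic, and the helper below is a hypothetical stand-in):

async def _handshake_sketch(creds, conn, address):
    # Build a mechanism-specific context, e.g. _ScramContext or _X509Context.
    ctx = _AuthContext.from_credentials(creds, address)
    hello_cmd = {"hello": 1}
    if ctx:
        spec = ctx.speculate_command()
        if spec is not None:
            # The first auth round trip is piggybacked on the hello command.
            hello_cmd["speculativeAuthenticate"] = spec
    hello = await conn.command("admin", hello_cmd)
    if ctx:
        ctx.speculative_authenticate = hello.get("speculativeAuthenticate")
    conn.auth_ctx = ctx
    # authenticate() above then skips a round trip when speculation succeeded.
    await authenticate(creds, conn)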
100 venv/lib/python3.11/site-packages/pymongo/asynchronous/auth_aws.py (Normal file)
@@ -0,0 +1,100 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MONGODB-AWS Authentication helpers."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Mapping, Type

import bson
from bson.binary import Binary
from pymongo.errors import ConfigurationError, OperationFailure

if TYPE_CHECKING:
    from bson.typings import _ReadableBuffer
    from pymongo.asynchronous.pool import AsyncConnection
    from pymongo.auth_shared import MongoCredential

_IS_SYNC = False


async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None:
    """Authenticate using MONGODB-AWS."""
    try:
        import pymongo_auth_aws  # type:ignore[import]
    except ImportError as e:
        raise ConfigurationError(
            "MONGODB-AWS authentication requires pymongo-auth-aws: "
            "install with: python -m pip install 'pymongo[aws]'"
        ) from e
    # Delayed import.
    from pymongo_auth_aws.auth import (  # type:ignore[import]
        set_cached_credentials,
        set_use_cached_credentials,
    )

    set_use_cached_credentials(True)

    if conn.max_wire_version < 9:
        raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")

    class AwsSaslContext(pymongo_auth_aws.AwsSaslContext):  # type: ignore
        # Dependency injection:
        def binary_type(self) -> Type[Binary]:
            """Return the bson.binary.Binary type."""
            return Binary

        def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
            """Encode a dictionary to BSON."""
            return bson.encode(doc)

        def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
            """Decode BSON to a dictionary."""
            return bson.decode(data)

    try:
        ctx = AwsSaslContext(
            pymongo_auth_aws.AwsCredential(
                credentials.username,
                credentials.password,
                credentials.mechanism_properties.aws_session_token,
            )
        )
        client_payload = ctx.step(None)
        client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload}
        server_first = await conn.command("$external", client_first)
        res = server_first
        # Limit how many times we loop to catch protocol / library issues
        for _ in range(10):
            client_payload = ctx.step(res["payload"])
            cmd = {
                "saslContinue": 1,
                "conversationId": server_first["conversationId"],
                "payload": client_payload,
            }
            res = await conn.command("$external", cmd)
            if res["done"]:
                # SASL complete.
                break
    except pymongo_auth_aws.PyMongoAuthAwsError as exc:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        # Convert to OperationFailure and include pymongo-auth-aws version.
        raise OperationFailure(
            f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
        ) from None
    except Exception:
        # Clear the cached credentials if we hit a failure in auth.
        set_cached_credentials(None)
        raise
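In practice MONGODB-AWS is selected through the connection options rather than by calling _authenticate_aws directly. A short, hedged sketch (the URI is a placeholder; it assumes pymongo[aws] is installed and AWS credentials are available via the environment or instance metadata):

from pymongo import AsyncMongoClient

client = AsyncMongoClient(
    "mongodb+srv://cluster0.example.mongodb.net/"
    "?authMechanism=MONGODB-AWS&authSource=%24external"
)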
294 venv/lib/python3.11/site-packages/pymongo/asynchronous/auth_oidc.py (Normal file)
@@ -0,0 +1,294 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MONGODB-OIDC Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
from pymongo._csot import remaining
|
||||
from pymongo.auth_oidc_shared import (
|
||||
CALLBACK_VERSION,
|
||||
HUMAN_CALLBACK_TIMEOUT_SECONDS,
|
||||
MACHINE_CALLBACK_TIMEOUT_SECONDS,
|
||||
TIME_BETWEEN_CALLS_SECONDS,
|
||||
OIDCCallback,
|
||||
OIDCCallbackContext,
|
||||
OIDCCallbackResult,
|
||||
OIDCIdPInfo,
|
||||
_OIDCProperties,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.auth_shared import MongoCredential
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _get_authenticator(
|
||||
credentials: MongoCredential, address: tuple[str, int]
|
||||
) -> _OIDCAuthenticator:
|
||||
if credentials.cache.data:
|
||||
return credentials.cache.data
|
||||
|
||||
# Extract values.
|
||||
principal_name = credentials.username
|
||||
properties = credentials.mechanism_properties
|
||||
|
||||
# Validate that the address is allowed.
|
||||
if not properties.environment:
|
||||
found = False
|
||||
allowed_hosts = properties.allowed_hosts
|
||||
for patt in allowed_hosts:
|
||||
if patt == address[0]:
|
||||
found = True
|
||||
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
|
||||
found = True
|
||||
if not found:
|
||||
raise ConfigurationError(
|
||||
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
|
||||
)
|
||||
|
||||
# Get or create the cache data.
|
||||
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
|
||||
return credentials.cache.data
|
||||
|
||||
|
||||
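# Illustrative sketch (not pymongo code) of the allowed-hosts check performed
# in _get_authenticator above: exact names match literally, "*." patterns
# match any subdomain but not the bare domain.
def _allowed_hosts_example() -> None:
    def matches(patt: str, host: str) -> bool:
        return patt == host or (patt.startswith("*.") and host.endswith(patt[1:]))

    assert matches("localhost", "localhost")
    assert matches("*.mongodb.net", "cluster0.mongodb.net")
    assert not matches("*.mongodb.net", "mongodb.net")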
@dataclass
|
||||
class _OIDCAuthenticator:
|
||||
username: str
|
||||
properties: _OIDCProperties
|
||||
refresh_token: Optional[str] = field(default=None)
|
||||
access_token: Optional[str] = field(default=None)
|
||||
idp_info: Optional[OIDCIdPInfo] = field(default=None)
|
||||
token_gen_id: int = field(default=0)
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
last_call_time: float = field(default=0)
|
||||
|
||||
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle a reauthenticate from the server."""
|
||||
# Invalidate the token for the connection.
|
||||
self._invalidate(conn)
|
||||
# Call the appropriate auth logic for the callback type.
|
||||
if self.properties.callback:
|
||||
return await self._authenticate_machine(conn)
|
||||
return await self._authenticate_human(conn)
|
||||
|
||||
async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
"""Handle an initial authenticate request."""
|
||||
# First handle speculative auth.
|
||||
# If it succeeded, we are done.
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
resp = ctx.speculative_authenticate
|
||||
if resp and resp["done"]:
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
return resp
|
||||
|
||||
# If spec auth failed, call the appropriate auth logic for the callback type.
|
||||
# We cannot assume that the token is invalid, because a proxy may have been
|
||||
# involved that stripped the speculative auth information.
|
||||
if self.properties.callback:
|
||||
return await self._authenticate_machine(conn)
|
||||
return await self._authenticate_human(conn)
|
||||
|
||||
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
|
||||
"""Get the appropriate speculative auth command."""
|
||||
if not self.access_token:
|
||||
return None
|
||||
return self._get_start_command({"jwt": self.access_token})
|
||||
|
||||
async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]:
|
||||
# If there is a cached access token, try to authenticate with it. If
|
||||
# authentication fails with error code 18, invalidate the access token,
|
||||
# fetch a new access token, and try to authenticate again. If authentication
|
||||
# fails for any other reason, raise the error to the user.
|
||||
if self.access_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
return await self._authenticate_machine(conn)
|
||||
raise
|
||||
return await self._sasl_start_jwt(conn)
|
||||
|
||||
async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
|
||||
# If we have a cached access token, try a JwtStepRequest.
|
||||
# If authentication fails with error code 18, invalidate the access token,
|
||||
# and try to authenticate again. If authentication fails for any other
|
||||
# reason, raise the error to the user.
|
||||
if self.access_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
return await self._authenticate_human(conn)
|
||||
raise
|
||||
|
||||
# If we have a cached refresh token, try a JwtStepRequest with that.
|
||||
# If authentication fails with error code 18, invalidate the access and
|
||||
# refresh tokens, and try to authenticate again. If authentication fails for
|
||||
# any other reason, raise the error to the user.
|
||||
if self.refresh_token:
|
||||
try:
|
||||
return await self._sasl_start_jwt(conn)
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
self.refresh_token = None
|
||||
return await self._authenticate_human(conn)
|
||||
raise
|
||||
|
||||
# Start a new Two-Step SASL conversation.
|
||||
# Run a PrincipalStepRequest to get the IdpInfo.
|
||||
cmd = self._get_start_command(None)
|
||||
start_resp = await self._run_command(conn, cmd)
|
||||
# Attempt to authenticate with a JwtStepRequest.
|
||||
return await self._sasl_continue_jwt(conn, start_resp)
|
||||
|
||||
def _get_access_token(self) -> Optional[str]:
|
||||
properties = self.properties
|
||||
cb: Union[None, OIDCCallback]
|
||||
resp: OIDCCallbackResult
|
||||
|
||||
is_human = properties.human_callback is not None
|
||||
if is_human and self.idp_info is None:
|
||||
return None
|
||||
|
||||
if properties.callback:
|
||||
cb = properties.callback
|
||||
if properties.human_callback:
|
||||
cb = properties.human_callback
|
||||
|
||||
prev_token = self.access_token
|
||||
if prev_token:
|
||||
return prev_token
|
||||
|
||||
if cb is None and not prev_token:
|
||||
return None
|
||||
|
||||
if not prev_token and cb is not None:
|
||||
with self.lock:
|
||||
# See if the token was changed while we were waiting for the
|
||||
# lock.
|
||||
new_token = self.access_token
|
||||
if new_token != prev_token:
|
||||
return new_token
|
||||
|
||||
# Ensure that we are waiting a min time between callback invocations.
|
||||
delta = time.time() - self.last_call_time
|
||||
if delta < TIME_BETWEEN_CALLS_SECONDS:
|
||||
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
|
||||
self.last_call_time = time.time()
|
||||
|
||||
if is_human:
|
||||
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
|
||||
assert self.idp_info is not None
|
||||
else:
|
||||
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
|
||||
context = OIDCCallbackContext(
|
||||
timeout_seconds=timeout,
|
||||
version=CALLBACK_VERSION,
|
||||
refresh_token=self.refresh_token,
|
||||
idp_info=self.idp_info,
|
||||
username=self.properties.username,
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError("Callback result must be of type OIDCCallbackResult")
|
||||
self.refresh_token = resp.refresh_token
|
||||
self.access_token = resp.access_token
|
||||
self.token_gen_id += 1
|
||||
|
||||
return self.access_token
|
||||
|
||||
async def _run_command(
|
||||
self, conn: AsyncConnection, cmd: MutableMapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
try:
|
||||
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as e:
|
||||
if self._is_auth_error(e):
|
||||
self._invalidate(conn)
|
||||
raise
|
||||
|
||||
def _is_auth_error(self, err: Exception) -> bool:
|
||||
if not isinstance(err, OperationFailure):
|
||||
return False
|
||||
return err.code == _AUTHENTICATION_FAILURE_CODE
|
||||
|
||||
def _invalidate(self, conn: AsyncConnection) -> None:
|
||||
# Ignore the invalidation if a token gen id is given and is less than our
|
||||
# current token gen id.
|
||||
token_gen_id = conn.oidc_token_gen_id or 0
|
||||
if token_gen_id is not None and token_gen_id < self.token_gen_id:
|
||||
return
|
||||
self.access_token = None
|
||||
|
||||
async def _sasl_continue_jwt(
|
||||
self, conn: AsyncConnection, start_resp: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
start_payload: dict = bson.decode(start_resp["payload"])
|
||||
if "issuer" in start_payload:
|
||||
self.idp_info = OIDCIdPInfo(**start_payload)
|
||||
access_token = self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
|
||||
access_token = self._get_access_token()
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
cmd = self._get_start_command({"jwt": access_token})
|
||||
return await self._run_command(conn, cmd)
|
||||
|
||||
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
|
||||
if payload is None:
|
||||
principal_name = self.username
|
||||
if principal_name:
|
||||
payload = {"n": principal_name}
|
||||
else:
|
||||
payload = {}
|
||||
bin_payload = Binary(bson.encode(payload))
|
||||
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
|
||||
|
||||
def _get_continue_command(
|
||||
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
|
||||
) -> MutableMapping[str, Any]:
|
||||
bin_payload = Binary(bson.encode(payload))
|
||||
return {
|
||||
"saslContinue": 1,
|
||||
"payload": bin_payload,
|
||||
"conversationId": start_resp["conversationId"],
|
||||
}
|
||||
|
||||
|
||||
async def _authenticate_oidc(
|
||||
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""Authenticate using MONGODB-OIDC."""
|
||||
authenticator = _get_authenticator(credentials, conn.address)
|
||||
if reauthenticate:
|
||||
return await authenticator.reauthenticate(conn)
|
||||
else:
|
||||
return await authenticator.authenticate(conn)
|
||||
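For context, the machine-callback path above is normally wired up by application code roughly as follows (a hedged sketch: the token path and URI are placeholders, and the imports assume the public pymongo.auth_oidc module that exposes the callback classes):

from pymongo import AsyncMongoClient
from pymongo.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult


class MyTokenCallback(OIDCCallback):
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        # Read a JWT from wherever the deployment provides it (placeholder path).
        with open("/var/run/secrets/oidc/token") as f:
            return OIDCCallbackResult(access_token=f.read().strip())


client = AsyncMongoClient(
    "mongodb://localhost:27017",
    authMechanism="MONGODB-OIDC",
    authMechanismProperties={"OIDC_CALLBACK": MyTokenCallback()},
)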
738 venv/lib/python3.11/site-packages/pymongo/asynchronous/bulk.py (Normal file)
@@ -0,0 +1,738 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The bulk write operations interface.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.bulk_shared import (
|
||||
_COMMANDS,
|
||||
_DELETE_ALL,
|
||||
_merge_command,
|
||||
_raise_bulk_write_error,
|
||||
_Run,
|
||||
)
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import (
|
||||
_DELETE,
|
||||
_INSERT,
|
||||
_UPDATE,
|
||||
_BulkWriteContext,
|
||||
_convert_exception,
|
||||
_convert_write_result,
|
||||
_EncryptedBulkWriteContext,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class _AsyncBulk:
|
||||
"""The private guts of the bulk write API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: AsyncCollection[_DocumentType],
|
||||
ordered: bool,
|
||||
bypass_document_validation: bool,
|
||||
comment: Optional[str] = None,
|
||||
let: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Initialize a _AsyncBulk instance."""
|
||||
self.collection = collection.with_options(
|
||||
codec_options=collection.codec_options._replace(
|
||||
unicode_decode_error_handler="replace", document_class=dict
|
||||
)
|
||||
)
|
||||
self.let = let
|
||||
if self.let is not None:
|
||||
common.validate_is_document_type("let", self.let)
|
||||
self.comment: Optional[str] = comment
|
||||
self.ordered = ordered
|
||||
self.ops: list[tuple[int, Mapping[str, Any]]] = []
|
||||
self.executed = False
|
||||
self.bypass_doc_val = bypass_document_validation
|
||||
self.uses_collation = False
|
||||
self.uses_array_filters = False
|
||||
self.uses_hint_update = False
|
||||
self.uses_hint_delete = False
|
||||
self.is_retryable = True
|
||||
self.retrying = False
|
||||
self.started_retryable_write = False
|
||||
# Extra state so that we know where to pick up on a retry attempt.
|
||||
self.current_run = None
|
||||
self.next_run = None
|
||||
self.is_encrypted = False
|
||||
|
||||
@property
|
||||
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
|
||||
encrypter = self.collection.database.client._encrypter
|
||||
if encrypter and not encrypter._bypass_auto_encryption:
|
||||
self.is_encrypted = True
|
||||
return _EncryptedBulkWriteContext
|
||||
else:
|
||||
self.is_encrypted = False
|
||||
return _BulkWriteContext
|
||||
|
||||
def add_insert(self, document: _DocumentOut) -> None:
|
||||
"""Add an insert document to the list of ops."""
|
||||
validate_is_document_type("document", document)
|
||||
# Generate ObjectId client side.
|
||||
if not (isinstance(document, RawBSONDocument) or "_id" in document):
|
||||
document["_id"] = ObjectId()
|
||||
self.ops.append((_INSERT, document))
|
||||
|
||||
def add_update(
|
||||
self,
|
||||
selector: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
multi: bool = False,
|
||||
upsert: bool = False,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create an update document and add it to the list of ops."""
|
||||
validate_ok_for_update(update)
|
||||
cmd: dict[str, Any] = dict( # noqa: C406
|
||||
[("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]
|
||||
)
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
if array_filters is not None:
|
||||
self.uses_array_filters = True
|
||||
cmd["arrayFilters"] = array_filters
|
||||
if hint is not None:
|
||||
self.uses_hint_update = True
|
||||
cmd["hint"] = hint
|
||||
if multi:
|
||||
# A bulk_write containing an update_many is not retryable.
|
||||
self.is_retryable = False
|
||||
self.ops.append((_UPDATE, cmd))
|
||||
|
||||
def add_replace(
|
||||
self,
|
||||
selector: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
upsert: bool = False,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create a replace document and add it to the list of ops."""
|
||||
validate_ok_for_replace(replacement)
|
||||
cmd = {"q": selector, "u": replacement, "multi": False, "upsert": upsert}
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
if hint is not None:
|
||||
self.uses_hint_update = True
|
||||
cmd["hint"] = hint
|
||||
self.ops.append((_UPDATE, cmd))
|
||||
|
||||
def add_delete(
|
||||
self,
|
||||
selector: Mapping[str, Any],
|
||||
limit: int,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create a delete document and add it to the list of ops."""
|
||||
cmd = {"q": selector, "limit": limit}
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
if hint is not None:
|
||||
self.uses_hint_delete = True
|
||||
cmd["hint"] = hint
|
||||
if limit == _DELETE_ALL:
|
||||
# A bulk_write containing a delete_many is not retryable.
|
||||
self.is_retryable = False
|
||||
self.ops.append((_DELETE, cmd))
|
||||
|
||||
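# Illustrative sketch (not pymongo code): the public bulk_write call that the
# add_insert / add_update / add_delete helpers above serve. The collection
# argument is a placeholder AsyncCollection.
async def _bulk_write_example(coll) -> None:
    from pymongo import DeleteMany, InsertOne, UpdateOne

    requests = [
        InsertOne({"sku": "a1", "qty": 5}),  # -> add_insert
        UpdateOne({"sku": "a1"}, {"$inc": {"qty": 1}}, upsert=True),  # -> add_update
        DeleteMany({"qty": {"$lte": 0}}),  # -> add_delete; makes the batch non-retryable
    ]
    result = await coll.bulk_write(requests, ordered=True)
    print(result.inserted_count, result.modified_count, result.deleted_count)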
def gen_ordered(self) -> Iterator[Optional[_Run]]:
|
||||
"""Generate batches of operations, batched by type of
|
||||
operation, in the order **provided**.
|
||||
"""
|
||||
run = None
|
||||
for idx, (op_type, operation) in enumerate(self.ops):
|
||||
if run is None:
|
||||
run = _Run(op_type)
|
||||
elif run.op_type != op_type:
|
||||
yield run
|
||||
run = _Run(op_type)
|
||||
run.add(idx, operation)
|
||||
yield run
|
||||
|
||||
def gen_unordered(self) -> Iterator[_Run]:
|
||||
"""Generate batches of operations, batched by type of
|
||||
operation, in arbitrary order.
|
||||
"""
|
||||
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
|
||||
for idx, (op_type, operation) in enumerate(self.ops):
|
||||
operations[op_type].add(idx, operation)
|
||||
|
||||
for run in operations:
|
||||
if run.ops:
|
||||
yield run
|
||||
|
||||
@_handle_reauth
|
||||
async def write_command(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
|
||||
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
|
||||
if bwc.publish:
|
||||
bwc._fail(request_id, failure, duration)
|
||||
# Process the response from the server.
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
|
||||
raise
|
||||
return reply # type: ignore[return-value]
|
||||
|
||||
async def unack_write(
|
||||
self,
|
||||
bwc: _BulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
cmd = bwc._start(cmd, request_id, docs)
|
||||
try:
|
||||
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if result is not None:
|
||||
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
|
||||
else:
|
||||
# Comply with APM spec.
|
||||
reply = {"ok": 1}
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration)
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
|
||||
elif isinstance(exc, NotPrimaryError):
|
||||
failure = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if bwc.publish:
|
||||
assert bwc.start_time is not None
|
||||
bwc._fail(request_id, failure, duration)
|
||||
raise
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
async def _execute_batch_unack(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> list[Mapping[str, Any]]:
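"""Send one unacknowledged batch of ops; return the ops that were actually sent."""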
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
await bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
write_concern=WriteConcern(w=0),
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
# Though this isn't strictly a "legacy" write, the helper
|
||||
# handles publishing commands and sending our message
|
||||
# without receiving a result. Send 0 for max_doc_size
|
||||
# to disable size checking. Size checking is handled while
|
||||
# the documents are encoded to BSON.
|
||||
await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
|
||||
|
||||
return to_send
|
||||
|
||||
async def _execute_batch(
|
||||
self,
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
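"""Send one acknowledged batch of ops; return the server reply and the ops sent."""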
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
result = await bwc.conn.command( # type: ignore[misc]
|
||||
bwc.db_name,
|
||||
batched_cmd, # type: ignore[arg-type]
|
||||
codec_options=bwc.codec,
|
||||
session=bwc.session, # type: ignore[arg-type]
|
||||
client=client, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
request_id, msg, to_send = bwc.batch_command(cmd, ops)
|
||||
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
|
||||
|
||||
return result, to_send # type: ignore[return-value]
|
||||
|
||||
async def _execute_command(
|
||||
self,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
op_id: int,
|
||||
retryable: bool,
|
||||
full_result: MutableMapping[str, Any],
|
||||
final_write_concern: Optional[WriteConcern] = None,
|
||||
) -> None:
|
||||
db_name = self.collection.database.name
|
||||
client = self.collection.database.client
|
||||
listeners = client._event_listeners
|
||||
|
||||
if not self.current_run:
|
||||
self.current_run = next(generator)
|
||||
self.next_run = None
|
||||
run = self.current_run
|
||||
|
||||
# AsyncConnection.command validates the session, but we use
|
||||
# AsyncConnection.write_command
|
||||
conn.validate_session(client, session)
|
||||
last_run = False
|
||||
|
||||
while run:
|
||||
if not self.retrying:
|
||||
self.next_run = next(generator, None)
|
||||
if self.next_run is None:
|
||||
last_run = True
|
||||
|
||||
cmd_name = _COMMANDS[run.op_type]
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
conn,
|
||||
op_id,
|
||||
listeners,
|
||||
session,
|
||||
run.op_type,
|
||||
self.collection.codec_options,
|
||||
)
|
||||
|
||||
while run.idx_offset < len(run.ops):
|
||||
# If this is the last possible operation, use the
|
||||
# final write concern.
|
||||
if last_run and (len(run.ops) - run.idx_offset) == 1:
|
||||
write_concern = final_write_concern or write_concern
|
||||
|
||||
cmd = {cmd_name: self.collection.name, "ordered": self.ordered}
|
||||
if self.comment:
|
||||
cmd["comment"] = self.comment
|
||||
_csot.apply_write_concern(cmd, write_concern)
|
||||
if self.bypass_doc_val:
|
||||
cmd["bypassDocumentValidation"] = True
|
||||
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
|
||||
cmd["let"] = self.let
|
||||
if session:
|
||||
# Start a new retryable write unless one was already
|
||||
# started for this command.
|
||||
if retryable and not self.started_retryable_write:
|
||||
session._start_retryable_write()
|
||||
self.started_retryable_write = True
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
|
||||
conn.send_cluster_time(cmd, session, client)
|
||||
conn.add_server_api(cmd)
|
||||
# CSOT: apply timeout before encoding the command.
|
||||
conn.apply_timeout(client, cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
|
||||
# Run as many ops as possible in one command.
|
||||
if write_concern.acknowledged:
|
||||
result, to_send = await self._execute_batch(bwc, cmd, ops, client)
|
||||
|
||||
# Retryable writeConcernErrors halt the execution of this run.
|
||||
wce = result.get("writeConcernError", {})
|
||||
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
|
||||
# Synthesize the full bulk result without modifying the
|
||||
# current one because this write operation may be retried.
|
||||
full = copy.deepcopy(full_result)
|
||||
_merge_command(run, full, run.idx_offset, result)
|
||||
_raise_bulk_write_error(full)
|
||||
|
||||
_merge_command(run, full_result, run.idx_offset, result)
|
||||
|
||||
# We're no longer in a retry once a command succeeds.
|
||||
self.retrying = False
|
||||
self.started_retryable_write = False
|
||||
|
||||
if self.ordered and "writeErrors" in result:
|
||||
break
|
||||
else:
|
||||
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
|
||||
run.idx_offset += len(to_send)
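# batch_command encodes only as many ops as fit within the server's size
# limits, so advance by the number actually sent; any remaining ops are
# sent on the next pass through this loop.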
|
||||
|
||||
# We're supposed to continue if errors are
|
||||
# at the write concern level (e.g. wtimeout)
|
||||
if self.ordered and full_result["writeErrors"]:
|
||||
break
|
||||
# Reset our state
|
||||
self.current_run = run = self.next_run
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute using write commands."""
|
||||
# nModified is only reported for write commands, not legacy ops.
|
||||
full_result = {
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [],
|
||||
"nInserted": 0,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 0,
|
||||
"nModified": 0,
|
||||
"nRemoved": 0,
|
||||
"upserted": [],
|
||||
}
|
||||
op_id = _randint()
|
||||
|
||||
async def retryable_bulk(
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
|
||||
) -> None:
|
||||
await self._execute_command(
|
||||
generator,
|
||||
write_concern,
|
||||
session,
|
||||
conn,
|
||||
op_id,
|
||||
retryable,
|
||||
full_result,
|
||||
)
|
||||
|
||||
client = self.collection.database.client
|
||||
_ = await client._retryable_write(
|
||||
self.is_retryable,
|
||||
retryable_bulk,
|
||||
session,
|
||||
operation,
|
||||
bulk=self, # type: ignore[arg-type]
|
||||
operation_id=op_id,
|
||||
)
|
||||
|
||||
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
|
||||
_raise_bulk_write_error(full_result)
|
||||
return full_result
|
||||
|
||||
async def execute_op_msg_no_results(
|
||||
self, conn: AsyncConnection, generator: Iterator[Any]
|
||||
) -> None:
|
||||
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
|
||||
db_name = self.collection.database.name
|
||||
client = self.collection.database.client
|
||||
listeners = client._event_listeners
|
||||
op_id = _randint()
|
||||
|
||||
if not self.current_run:
|
||||
self.current_run = next(generator)
|
||||
run = self.current_run
|
||||
|
||||
while run:
|
||||
cmd_name = _COMMANDS[run.op_type]
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
conn,
|
||||
op_id,
|
||||
listeners,
|
||||
None,
|
||||
run.op_type,
|
||||
self.collection.codec_options,
|
||||
)
|
||||
|
||||
while run.idx_offset < len(run.ops):
|
||||
cmd = {
|
||||
cmd_name: self.collection.name,
|
||||
"ordered": False,
|
||||
"writeConcern": {"w": 0},
|
||||
}
|
||||
conn.add_server_api(cmd)
|
||||
ops = islice(run.ops, run.idx_offset, None)
|
||||
# Run as many ops as possible.
|
||||
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
|
||||
run.idx_offset += len(to_send)
|
||||
self.current_run = run = next(generator, None)
|
||||
|
||||
async def execute_command_no_results(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
"""Execute write commands with OP_MSG and w=0 WriteConcern, ordered."""
|
||||
full_result = {
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [],
|
||||
"nInserted": 0,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 0,
|
||||
"nModified": 0,
|
||||
"nRemoved": 0,
|
||||
"upserted": [],
|
||||
}
|
||||
# Ordered bulk writes have to be acknowledged so that we stop
|
||||
# processing at the first error, even when the application
|
||||
# specified unacknowledged writeConcern.
|
||||
initial_write_concern = WriteConcern()
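# The caller's original (w=0) write concern is only applied to the very
# last operation via final_write_concern below; every earlier command is
# acknowledged so ordered execution can stop at the first error.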
|
||||
op_id = _randint()
|
||||
try:
|
||||
await self._execute_command(
|
||||
generator,
|
||||
initial_write_concern,
|
||||
None,
|
||||
conn,
|
||||
op_id,
|
||||
False,
|
||||
full_result,
|
||||
write_concern,
|
||||
)
|
||||
except OperationFailure:
|
||||
pass
|
||||
|
||||
async def execute_no_results(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
generator: Iterator[Any],
|
||||
write_concern: WriteConcern,
|
||||
) -> None:
|
||||
"""Execute all operations, returning no results (w=0)."""
|
||||
if self.uses_collation:
|
||||
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
|
||||
if self.uses_array_filters:
|
||||
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
|
||||
# Guard against unsupported unacknowledged writes.
|
||||
unack = write_concern and not write_concern.acknowledged
|
||||
if unack and self.uses_hint_delete and conn.max_wire_version < 9:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
|
||||
)
|
||||
if unack and self.uses_hint_update and conn.max_wire_version < 8:
|
||||
raise ConfigurationError(
|
||||
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
|
||||
)
|
||||
# Cannot have both unacknowledged writes and bypass document validation.
|
||||
if self.bypass_doc_val:
|
||||
raise OperationFailure(
|
||||
"Cannot set bypass_document_validation with unacknowledged write concern"
|
||||
)
|
||||
|
||||
if self.ordered:
|
||||
return await self.execute_command_no_results(conn, generator, write_concern)
|
||||
return await self.execute_op_msg_no_results(conn, generator)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> Any:
|
||||
"""Execute operations."""
|
||||
if not self.ops:
|
||||
raise InvalidOperation("No operations to execute")
|
||||
if self.executed:
|
||||
raise InvalidOperation("Bulk operations can only be executed once.")
|
||||
self.executed = True
|
||||
write_concern = write_concern or self.collection.write_concern
|
||||
session = _validate_session_write_concern(session, write_concern)
|
||||
|
||||
if self.ordered:
|
||||
generator = self.gen_ordered()
|
||||
else:
|
||||
generator = self.gen_unordered()
|
||||
|
||||
client = self.collection.database.client
|
||||
if not write_concern.acknowledged:
|
||||
async with await client._conn_for_writes(session, operation) as connection:
|
||||
await self.execute_no_results(connection, generator, write_concern)
|
||||
return None
|
||||
else:
|
||||
return await self.execute_command(generator, write_concern, session, operation)
|
||||
@@ -0,0 +1,498 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Watch changes on a collection, a database, or the entire cluster."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
|
||||
|
||||
from bson import CodecOptions, _bson_to_dict
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.aggregation import (
|
||||
_AggregationCommand,
|
||||
_CollectionAggregationCommand,
|
||||
_DatabaseAggregationCommand,
|
||||
)
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
CursorNotFound,
|
||||
InvalidOperation,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# The change streams spec considers the following server errors from the
|
||||
# getMore command resumable. All other getMore errors are non-resumable.
|
||||
_RESUMABLE_GETMORE_ERRORS = frozenset(
|
||||
[
|
||||
6, # HostUnreachable
|
||||
7, # HostNotFound
|
||||
89, # NetworkTimeout
|
||||
91, # ShutdownInProgress
|
||||
189, # PrimarySteppedDown
|
||||
262, # ExceededTimeLimit
|
||||
9001, # SocketException
|
||||
10107, # NotWritablePrimary
|
||||
11600, # InterruptedAtShutdown
|
||||
11602, # InterruptedDueToReplStateChange
|
||||
13435, # NotPrimaryNoSecondaryOk
|
||||
13436, # NotPrimaryOrSecondary
|
||||
63, # StaleShardVersion
|
||||
150, # StaleEpoch
|
||||
13388, # StaleConfig
|
||||
234, # RetryChangeStream
|
||||
133, # FailedToSatisfyReadPreference
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
|
||||
|
||||
def _resumable(exc: PyMongoError) -> bool:
|
||||
"""Return True if given a resumable change stream error."""
|
||||
if isinstance(exc, (ConnectionFailure, CursorNotFound)):
|
||||
return True
|
||||
if isinstance(exc, OperationFailure):
|
||||
if exc._max_wire_version is None:
|
||||
return False
|
||||
return (
|
||||
exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
|
||||
) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
|
||||
return False
|
||||
|
||||
|
||||
class AsyncChangeStream(Generic[_DocumentType]):
|
||||
"""The internal abstract base class for change stream cursors.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
:meth:`pymongo.asynchronous.collection.AsyncCollection.watch`,
|
||||
:meth:`pymongo.asynchronous.database.AsyncDatabase.watch`, or
|
||||
:meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` instead.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[
|
||||
AsyncMongoClient[_DocumentType],
|
||||
AsyncDatabase[_DocumentType],
|
||||
AsyncCollection[_DocumentType],
|
||||
],
|
||||
pipeline: Optional[_Pipeline],
|
||||
full_document: Optional[str],
|
||||
resume_after: Optional[Mapping[str, Any]],
|
||||
max_await_time_ms: Optional[int],
|
||||
batch_size: Optional[int],
|
||||
collation: Optional[_CollationIn],
|
||||
start_at_operation_time: Optional[Timestamp],
|
||||
session: Optional[AsyncClientSession],
|
||||
start_after: Optional[Mapping[str, Any]],
|
||||
comment: Optional[Any] = None,
|
||||
full_document_before_change: Optional[str] = None,
|
||||
show_expanded_events: Optional[bool] = None,
|
||||
) -> None:
|
||||
if pipeline is None:
|
||||
pipeline = []
|
||||
pipeline = common.validate_list("pipeline", pipeline)
|
||||
common.validate_string_or_none("full_document", full_document)
|
||||
validate_collation_or_none(collation)
|
||||
common.validate_non_negative_integer_or_none("batchSize", batch_size)
|
||||
|
||||
self._decode_custom = False
|
||||
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
|
||||
if target.codec_options.type_registry._decoder_map:
|
||||
self._decode_custom = True
|
||||
# Keep the type registry so that we support encoding custom types
|
||||
# in the pipeline.
|
||||
self._target = target.with_options( # type: ignore
|
||||
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
|
||||
)
|
||||
else:
|
||||
self._target = target
|
||||
|
||||
self._pipeline = copy.deepcopy(pipeline)
|
||||
self._full_document = full_document
|
||||
self._full_document_before_change = full_document_before_change
|
||||
self._uses_start_after = start_after is not None
|
||||
self._uses_resume_after = resume_after is not None
|
||||
self._resume_token = copy.deepcopy(start_after or resume_after)
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._batch_size = batch_size
|
||||
self._collation = collation
|
||||
self._start_at_operation_time = start_at_operation_time
|
||||
self._session = session
|
||||
self._comment = comment
|
||||
self._closed = False
|
||||
self._timeout = self._target._timeout
|
||||
self._show_expanded_events = show_expanded_events
|
||||
|
||||
async def _initialize_cursor(self) -> None:
|
||||
# Initialize cursor.
|
||||
self._cursor = await self._create_cursor()
|
||||
|
||||
@property
|
||||
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
|
||||
"""The aggregation command class to be used."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncMongoClient:
|
||||
"""The client against which the aggregation commands for
|
||||
this AsyncChangeStream will be run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _change_stream_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the $changeStream pipeline stage."""
|
||||
options: dict[str, Any] = {}
|
||||
if self._full_document is not None:
|
||||
options["fullDocument"] = self._full_document
|
||||
|
||||
if self._full_document_before_change is not None:
|
||||
options["fullDocumentBeforeChange"] = self._full_document_before_change
|
||||
|
||||
resume_token = self.resume_token
|
||||
if resume_token is not None:
|
||||
if self._uses_start_after:
|
||||
options["startAfter"] = resume_token
|
||||
else:
|
||||
options["resumeAfter"] = resume_token
|
||||
|
||||
elif self._start_at_operation_time is not None:
|
||||
options["startAtOperationTime"] = self._start_at_operation_time
|
||||
|
||||
if self._show_expanded_events:
|
||||
options["showExpandedEvents"] = self._show_expanded_events
|
||||
|
||||
return options
|
||||
|
||||
def _command_options(self) -> dict[str, Any]:
|
||||
"""Return the options dict for the aggregation command."""
|
||||
options = {}
|
||||
if self._max_await_time_ms is not None:
|
||||
options["maxAwaitTimeMS"] = self._max_await_time_ms
|
||||
if self._batch_size is not None:
|
||||
options["batchSize"] = self._batch_size
|
||||
return options
|
||||
|
||||
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
|
||||
"""Return the full aggregation pipeline for this AsyncChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline: list = [{"$changeStream": options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None:
|
||||
"""Callback that caches the postBatchResumeToken or
|
||||
startAtOperationTime from a changeStream aggregate command response
|
||||
containing an empty batch of change documents.
|
||||
|
||||
This is implemented as a callback because we need access to the wire
|
||||
version in order to determine whether to cache this value.
|
||||
"""
|
||||
if not result["cursor"]["firstBatch"]:
|
||||
if "postBatchResumeToken" in result["cursor"]:
|
||||
self._resume_token = result["cursor"]["postBatchResumeToken"]
|
||||
elif (
|
||||
self._start_at_operation_time is None
|
||||
and self._uses_resume_after is False
|
||||
and self._uses_start_after is False
|
||||
and conn.max_wire_version >= 7
|
||||
):
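# Wire version 7 corresponds to MongoDB 4.0, the first release that
# supports resuming a change stream with startAtOperationTime.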
|
||||
self._start_at_operation_time = result.get("operationTime")
|
||||
# PYTHON-2181: informative error on missing operationTime.
|
||||
if self._start_at_operation_time is None:
|
||||
raise OperationFailure(
|
||||
"Expected field 'operationTime' missing from command "
|
||||
f"response : {result!r}"
|
||||
)
|
||||
|
||||
async def _run_aggregation_cmd(
|
||||
self, session: Optional[AsyncClientSession], explicit_session: bool
|
||||
) -> AsyncCommandCursor:
|
||||
"""Run the full aggregation pipeline for this AsyncChangeStream and return
|
||||
the corresponding AsyncCommandCursor.
|
||||
"""
|
||||
cmd = self._aggregation_command_class(
|
||||
self._target,
|
||||
AsyncCommandCursor,
|
||||
self._aggregation_pipeline(),
|
||||
self._command_options(),
|
||||
explicit_session,
|
||||
result_processor=self._process_result,
|
||||
comment=self._comment,
|
||||
)
|
||||
return await self._client._retryable_read(
|
||||
cmd.get_cursor,
|
||||
self._target._read_preference_for(session),
|
||||
session,
|
||||
operation=_Op.AGGREGATE,
|
||||
)
|
||||
|
||||
async def _create_cursor(self) -> AsyncCommandCursor:
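# Reuse the caller's explicit session if one was provided; otherwise
# borrow an implicit session that is left open (close=False) because the
# resulting cursor continues to use it.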
|
||||
async with self._client._tmp_session(self._session, close=False) as s:
|
||||
return await self._run_aggregation_cmd(
|
||||
session=s, explicit_session=self._session is not None
|
||||
)
|
||||
|
||||
async def _resume(self) -> None:
|
||||
"""Reestablish this change stream after a resumable error."""
|
||||
try:
|
||||
await self._cursor.close()
|
||||
except PyMongoError:
|
||||
pass
|
||||
self._cursor = await self._create_cursor()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close this AsyncChangeStream."""
|
||||
self._closed = True
|
||||
await self._cursor.close()
|
||||
|
||||
def __aiter__(self) -> AsyncChangeStream[_DocumentType]:
|
||||
return self
|
||||
|
||||
@property
|
||||
def resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""The cached resume token that will be used to resume after the most
|
||||
recently returned change.
|
||||
|
||||
.. versionadded:: 3.9
|
||||
"""
|
||||
return copy.deepcopy(self._resume_token)
|
||||
|
||||
@_csot.apply
|
||||
async def next(self) -> _DocumentType:
|
||||
"""Advance the cursor.
|
||||
|
||||
This method blocks until the next change document is returned or an
|
||||
unrecoverable error is raised. This method is used when iterating over
|
||||
all changes in the cursor. For example::
|
||||
|
||||
try:
|
||||
resume_token = None
|
||||
pipeline = [{'$match': {'operationType': 'insert'}}]
|
||||
async with await db.collection.watch(pipeline) as stream:
|
||||
async for insert_change in stream:
|
||||
print(insert_change)
|
||||
resume_token = stream.resume_token
|
||||
except pymongo.errors.PyMongoError:
|
||||
# The AsyncChangeStream encountered an unrecoverable error or the
|
||||
# resume attempt failed to recreate the cursor.
|
||||
if resume_token is None:
|
||||
# There is no usable resume token because there was a
|
||||
# failure during AsyncChangeStream initialization.
|
||||
logging.error('...')
|
||||
else:
|
||||
# Use the interrupted AsyncChangeStream's resume token to create
|
||||
# a new AsyncChangeStream. The new stream will continue from the
|
||||
# last seen insert change without missing any events.
|
||||
async with await db.collection.watch(
|
||||
pipeline, resume_after=resume_token) as stream:
|
||||
async for insert_change in stream:
|
||||
print(insert_change)
|
||||
|
||||
Raises :exc:`StopAsyncIteration` if this AsyncChangeStream is closed.
|
||||
"""
|
||||
while self.alive:
|
||||
doc = await self.try_next()
|
||||
if doc is not None:
|
||||
return doc
|
||||
|
||||
raise StopAsyncIteration
|
||||
|
||||
__anext__ = next
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
:exc:`StopAsyncIteration` and :meth:`try_next` can return ``None``.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
return not self._closed
|
||||
|
||||
@_csot.apply
|
||||
async def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next change document without waiting
|
||||
indefinitely for the next change. For example::
|
||||
|
||||
async with await db.collection.watch() as stream:
|
||||
while stream.alive:
|
||||
change = await stream.try_next()
|
||||
# Note that the AsyncChangeStream's resume token may be updated
|
||||
# even when no changes are returned.
|
||||
print("Current resume token: %r" % (stream.resume_token,))
|
||||
if change is not None:
|
||||
print("Change document: %r" % (change,))
|
||||
continue
|
||||
# We end up here when there are no recent changes.
|
||||
# Sleep for a while before trying again to avoid flooding
|
||||
# the server with getMore requests when no changes are
|
||||
# available.
|
||||
await asyncio.sleep(10)
|
||||
|
||||
If no change document is cached locally then this method runs a single
|
||||
getMore command. If the getMore yields any documents, the next
|
||||
document is returned, otherwise, if the getMore returns no documents
|
||||
(because there have been no changes) then ``None`` is returned.
|
||||
|
||||
:return: The next change document or ``None`` when no document is available
|
||||
after running a single getMore or when the cursor is closed.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
"""
|
||||
if not self._closed and not self._cursor.alive:
|
||||
await self._resume()
|
||||
|
||||
# Attempt to get the next change with at most one getMore and at most
|
||||
# one resume attempt.
|
||||
try:
|
||||
try:
|
||||
change = await self._cursor._try_next(True)
|
||||
except PyMongoError as exc:
|
||||
if not _resumable(exc):
|
||||
raise
|
||||
await self._resume()
|
||||
change = await self._cursor._try_next(False)
|
||||
except PyMongoError as exc:
|
||||
# Close the stream after a fatal error.
|
||||
if not _resumable(exc) and not exc.timeout:
|
||||
await self.close()
|
||||
raise
|
||||
except Exception:
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
# Check if the cursor was invalidated.
|
||||
if not self._cursor.alive:
|
||||
self._closed = True
|
||||
|
||||
# If no changes are available.
|
||||
if change is None:
|
||||
# We have either iterated over all documents in the cursor,
|
||||
# OR the most-recently returned batch is empty. In either case,
|
||||
# update the cached resume token with the postBatchResumeToken if
|
||||
# one was returned. We also clear the startAtOperationTime.
|
||||
if self._cursor._post_batch_resume_token is not None:
|
||||
self._resume_token = self._cursor._post_batch_resume_token
|
||||
self._start_at_operation_time = None
|
||||
return change
|
||||
|
||||
# Else, changes are available.
|
||||
try:
|
||||
resume_token = change["_id"]
|
||||
except KeyError:
|
||||
await self.close()
|
||||
raise InvalidOperation(
|
||||
"Cannot provide resume functionality when the resume token is missing."
|
||||
) from None
|
||||
|
||||
# If this is the last change document from the current batch, cache the
|
||||
# postBatchResumeToken.
|
||||
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
|
||||
resume_token = self._cursor._post_batch_resume_token
|
||||
|
||||
# Hereafter, don't use startAfter; instead use resumeAfter.
|
||||
self._uses_start_after = False
|
||||
self._uses_resume_after = True
|
||||
|
||||
# Cache the resume token and clear startAtOperationTime.
|
||||
self._resume_token = resume_token
|
||||
self._start_at_operation_time = None
|
||||
|
||||
if self._decode_custom:
|
||||
return _bson_to_dict(change.raw, self._orig_codec_options)
|
||||
return change
|
||||
|
||||
async def __aenter__(self) -> AsyncChangeStream[_DocumentType]:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
|
||||
class AsyncCollectionChangeStream(AsyncChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on a single collection.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
helper method :meth:`pymongo.asynchronous.collection.AsyncCollection.watch` instead.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
"""
|
||||
|
||||
_target: AsyncCollection[_DocumentType]
|
||||
|
||||
@property
|
||||
def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
|
||||
return _CollectionAggregationCommand
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncMongoClient[_DocumentType]:
|
||||
return self._target.database.client
|
||||
|
||||
|
||||
class AsyncDatabaseChangeStream(AsyncChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in a database.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
helper method :meth:`pymongo.asynchronous.database.AsyncDatabase.watch` instead.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
"""
|
||||
|
||||
_target: AsyncDatabase[_DocumentType]
|
||||
|
||||
@property
|
||||
def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
|
||||
return _DatabaseAggregationCommand
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncMongoClient[_DocumentType]:
|
||||
return self._target.client
|
||||
|
||||
|
||||
class AsyncClusterChangeStream(AsyncDatabaseChangeStream[_DocumentType]):
|
||||
"""A change stream that watches changes on all collections in the cluster.
|
||||
|
||||
Should not be called directly by application developers. Use
|
||||
helper method :meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` instead.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
"""
|
||||
|
||||
def _change_stream_options(self) -> dict[str, Any]:
|
||||
options = super()._change_stream_options()
|
||||
options["allChangesForCluster"] = True
|
||||
return options
|
||||
@@ -0,0 +1,800 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The client-level bulk write operations interface.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import _csot, common
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo._client_bulk_shared import (
|
||||
_merge_command,
|
||||
_throw_client_bulk_write_exception,
|
||||
)
|
||||
from pymongo.common import (
|
||||
validate_is_document_type,
|
||||
validate_ok_for_replace,
|
||||
validate_ok_for_update,
|
||||
)
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
WaitQueueTimeoutError,
|
||||
)
|
||||
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import (
|
||||
_ClientBulkWriteContext,
|
||||
_convert_client_bulk_exception,
|
||||
_convert_exception,
|
||||
_convert_write_result,
|
||||
_randint,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import (
|
||||
ClientBulkWriteResult,
|
||||
DeleteResult,
|
||||
InsertOneResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from pymongo.typings import _DocumentOut, _Pipeline
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class _AsyncClientBulk:
|
||||
"""The private guts of the client-level bulk write API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncMongoClient,
|
||||
write_concern: WriteConcern,
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
comment: Optional[str] = None,
|
||||
let: Optional[Any] = None,
|
||||
verbose_results: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a _AsyncClientBulk instance."""
|
||||
self.client = client
|
||||
self.write_concern = write_concern
|
||||
self.let = let
|
||||
if self.let is not None:
|
||||
common.validate_is_document_type("let", self.let)
|
||||
self.ordered = ordered
|
||||
self.bypass_doc_val = bypass_document_validation
|
||||
self.comment = comment
|
||||
self.verbose_results = verbose_results
|
||||
|
||||
self.ops: list[tuple[str, Mapping[str, Any]]] = []
|
||||
self.namespaces: list[str] = []
|
||||
self.idx_offset: int = 0
|
||||
self.total_ops: int = 0
|
||||
|
||||
self.executed = False
|
||||
self.uses_upsert = False
|
||||
self.uses_collation = False
|
||||
self.uses_array_filters = False
|
||||
self.uses_hint_update = False
|
||||
self.uses_hint_delete = False
|
||||
|
||||
self.is_retryable = self.client.options.retry_writes
|
||||
self.retrying = False
|
||||
self.started_retryable_write = False
|
||||
|
||||
@property
|
||||
def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
|
||||
return _ClientBulkWriteContext
|
||||
|
||||
def add_insert(self, namespace: str, document: _DocumentOut) -> None:
|
||||
"""Add an insert document to the list of ops."""
|
||||
validate_is_document_type("document", document)
|
||||
# Generate ObjectId client side.
|
||||
if not (isinstance(document, RawBSONDocument) or "_id" in document):
|
||||
document["_id"] = ObjectId()
|
||||
cmd = {"insert": -1, "document": document}
|
||||
self.ops.append(("insert", cmd))
|
||||
self.namespaces.append(namespace)
|
||||
self.total_ops += 1
|
||||
|
||||
def add_update(
|
||||
self,
|
||||
namespace: str,
|
||||
selector: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
multi: bool = False,
|
||||
upsert: Optional[bool] = None,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
array_filters: Optional[list[Mapping[str, Any]]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create an update document and add it to the list of ops."""
|
||||
validate_ok_for_update(update)
|
||||
cmd = {
|
||||
"update": -1,
|
||||
"filter": selector,
|
||||
"updateMods": update,
|
||||
"multi": multi,
|
||||
}
|
||||
if upsert is not None:
|
||||
self.uses_upsert = True
|
||||
cmd["upsert"] = upsert
|
||||
if array_filters is not None:
|
||||
self.uses_array_filters = True
|
||||
cmd["arrayFilters"] = array_filters
|
||||
if hint is not None:
|
||||
self.uses_hint_update = True
|
||||
cmd["hint"] = hint
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
if multi:
|
||||
# A bulk_write containing an update_many is not retryable.
|
||||
self.is_retryable = False
|
||||
self.ops.append(("update", cmd))
|
||||
self.namespaces.append(namespace)
|
||||
self.total_ops += 1
|
||||
|
||||
def add_replace(
|
||||
self,
|
||||
namespace: str,
|
||||
selector: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
upsert: Optional[bool] = None,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create a replace document and add it to the list of ops."""
|
||||
validate_ok_for_replace(replacement)
|
||||
cmd = {
|
||||
"update": -1,
|
||||
"filter": selector,
|
||||
"updateMods": replacement,
|
||||
"multi": False,
|
||||
}
|
||||
if upsert is not None:
|
||||
self.uses_upsert = True
|
||||
cmd["upsert"] = upsert
|
||||
if hint is not None:
|
||||
self.uses_hint_update = True
|
||||
cmd["hint"] = hint
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
self.ops.append(("replace", cmd))
|
||||
self.namespaces.append(namespace)
|
||||
self.total_ops += 1
|
||||
|
||||
def add_delete(
|
||||
self,
|
||||
namespace: str,
|
||||
selector: Mapping[str, Any],
|
||||
multi: bool,
|
||||
collation: Optional[Mapping[str, Any]] = None,
|
||||
hint: Union[str, dict[str, Any], None] = None,
|
||||
) -> None:
|
||||
"""Create a delete document and add it to the list of ops."""
|
||||
cmd = {"delete": -1, "filter": selector, "multi": multi}
|
||||
if hint is not None:
|
||||
self.uses_hint_delete = True
|
||||
cmd["hint"] = hint
|
||||
if collation is not None:
|
||||
self.uses_collation = True
|
||||
cmd["collation"] = collation
|
||||
if multi:
|
||||
# A bulk_write containing a delete_many is not retryable.
|
||||
self.is_retryable = False
|
||||
self.ops.append(("delete", cmd))
|
||||
self.namespaces.append(namespace)
|
||||
self.total_ops += 1
|
||||
|
||||
@_handle_reauth
|
||||
async def write_command(
|
||||
self,
|
||||
bwc: _ClientBulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: Union[bytes, dict[str, Any]],
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for AsyncConnection.write_command that handles event publishing."""
|
||||
cmd["ops"] = op_docs
|
||||
cmd["nsInfo"] = ns_docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._start(cmd, request_id, op_docs, ns_docs)
|
||||
try:
|
||||
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
|
||||
# Process the response from the server.
|
||||
await self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
|
||||
if bwc.publish:
|
||||
bwc._fail(request_id, failure, duration)
|
||||
# Top-level error will be embedded in ClientBulkWriteException.
|
||||
reply = {"error": exc}
|
||||
# Process the response from the server.
|
||||
if isinstance(exc, OperationFailure):
|
||||
await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
|
||||
else:
|
||||
await self.client._process_response({}, bwc.session) # type: ignore[arg-type]
|
||||
return reply # type: ignore[return-value]
|
||||
|
||||
async def unack_write(
|
||||
self,
|
||||
bwc: _ClientBulkWriteContext,
|
||||
cmd: MutableMapping[str, Any],
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
cmd = bwc._start(cmd, request_id, op_docs, ns_docs)
|
||||
try:
|
||||
result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override]
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if result is not None:
|
||||
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
|
||||
else:
|
||||
# Comply with APM spec.
|
||||
reply = {"ok": 1}
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=reply,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
)
|
||||
if bwc.publish:
|
||||
bwc._succeed(request_id, reply, duration)
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - bwc.start_time
|
||||
if isinstance(exc, OperationFailure):
|
||||
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
|
||||
elif isinstance(exc, NotPrimaryError):
|
||||
failure = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=bwc.db_name,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=bwc.conn.id,
|
||||
serverConnectionId=bwc.conn.server_connection_id,
|
||||
serverHost=bwc.conn.address[0],
|
||||
serverPort=bwc.conn.address[1],
|
||||
serviceId=bwc.conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if bwc.publish:
|
||||
assert bwc.start_time is not None
|
||||
bwc._fail(request_id, failure, duration)
|
||||
# Top-level error will be embedded in ClientBulkWriteException.
|
||||
reply = {"error": exc}
|
||||
return reply
|
||||
|
||||
async def _execute_batch_unack(
|
||||
self,
|
||||
bwc: _ClientBulkWriteContext,
|
||||
cmd: dict[str, Any],
|
||||
ops: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""Executes a batch of bulkWrite server commands (unack)."""
|
||||
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
|
||||
await self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
|
||||
return to_send_ops, to_send_ns
|
||||
|
||||
async def _execute_batch(
|
||||
self,
|
||||
bwc: _ClientBulkWriteContext,
|
||||
cmd: dict[str, Any],
|
||||
ops: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""Executes a batch of bulkWrite server commands (ack)."""
|
||||
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
|
||||
result = await self.write_command(
|
||||
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
|
||||
) # type: ignore[arg-type]
|
||||
return result, to_send_ops, to_send_ns # type: ignore[return-value]
|
||||
|
||||
async def _process_results_cursor(
|
||||
self,
|
||||
full_result: MutableMapping[str, Any],
|
||||
result: MutableMapping[str, Any],
|
||||
conn: AsyncConnection,
|
||||
session: Optional[AsyncClientSession],
|
||||
) -> None:
|
||||
"""Internal helper for processing the server reply command cursor."""
|
||||
if result.get("cursor"):
|
||||
coll = AsyncCollection(
|
||||
database=AsyncDatabase(self.client, "admin"),
|
||||
name="$cmd.bulkWrite",
|
||||
)
|
||||
cmd_cursor = AsyncCommandCursor(
|
||||
coll,
|
||||
result["cursor"],
|
||||
conn.address,
|
||||
session=session,
|
||||
explicit_session=session is not None,
|
||||
comment=self.comment,
|
||||
)
|
||||
await cmd_cursor._maybe_pin_connection(conn)
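# Pin the connection to this cursor when required (e.g. when connected
# through a load balancer) so that getMores are routed to the same server.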
|
||||
|
||||
# Iterate the cursor to get individual write results.
|
||||
try:
|
||||
async for doc in cmd_cursor:
|
||||
original_index = doc["idx"] + self.idx_offset
|
||||
op_type, op = self.ops[original_index]
|
||||
|
||||
if not doc["ok"]:
|
||||
result["writeErrors"].append(doc)
|
||||
if self.ordered:
|
||||
return
|
||||
|
||||
# Record individual write result.
|
||||
if doc["ok"] and self.verbose_results:
|
||||
if op_type == "insert":
|
||||
inserted_id = op["document"]["_id"]
|
||||
res = InsertOneResult(inserted_id, acknowledged=True) # type: ignore[assignment]
|
||||
if op_type in ["update", "replace"]:
|
||||
op_type = "update"
|
||||
res = UpdateResult(doc, acknowledged=True, in_client_bulk=True) # type: ignore[assignment]
|
||||
if op_type == "delete":
|
||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||
full_result[f"{op_type}Results"][original_index] = res
|
||||
|
||||
except Exception as exc:
|
||||
# Attempt to close the cursor, then raise top-level error.
|
||||
if cmd_cursor.alive:
|
||||
await cmd_cursor.close()
|
||||
result["error"] = _convert_client_bulk_exception(exc)
|
||||
|
||||
async def _execute_command(
|
||||
self,
|
||||
write_concern: WriteConcern,
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
op_id: int,
|
||||
retryable: bool,
|
||||
full_result: MutableMapping[str, Any],
|
||||
final_write_concern: Optional[WriteConcern] = None,
|
||||
) -> None:
|
||||
"""Internal helper for executing batches of bulkWrite commands."""
|
||||
db_name = "admin"
|
||||
cmd_name = "bulkWrite"
|
||||
listeners = self.client._event_listeners
|
||||
|
||||
# AsyncConnection.command validates the session, but we use
|
||||
# AsyncConnection.write_command
|
||||
conn.validate_session(self.client, session)
|
||||
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
conn,
|
||||
op_id,
|
||||
listeners, # type: ignore[arg-type]
|
||||
session,
|
||||
self.client.codec_options,
|
||||
)
|
||||
|
||||
while self.idx_offset < self.total_ops:
|
||||
# If this is the last possible batch, use the
|
||||
# final write concern.
|
||||
if self.total_ops - self.idx_offset <= bwc.max_write_batch_size:
|
||||
write_concern = final_write_concern or write_concern
|
||||
|
||||
# Construct the server command, specifying the relevant options.
|
||||
cmd = {"bulkWrite": 1}
|
||||
cmd["errorsOnly"] = not self.verbose_results
|
||||
cmd["ordered"] = self.ordered # type: ignore[assignment]
|
||||
not_in_transaction = session and not session.in_transaction
|
||||
if not_in_transaction or not session:
|
||||
_csot.apply_write_concern(cmd, write_concern)
|
||||
if self.bypass_doc_val is not None:
|
||||
cmd["bypassDocumentValidation"] = self.bypass_doc_val
|
||||
if self.comment:
|
||||
cmd["comment"] = self.comment # type: ignore[assignment]
|
||||
if self.let:
|
||||
cmd["let"] = self.let
|
||||
|
||||
if session:
|
||||
# Start a new retryable write unless one was already
|
||||
# started for this command.
|
||||
if retryable and not self.started_retryable_write:
|
||||
session._start_retryable_write()
|
||||
self.started_retryable_write = True
|
||||
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
|
||||
conn.send_cluster_time(cmd, session, self.client)
|
||||
conn.add_server_api(cmd)
|
||||
# CSOT: apply timeout before encoding the command.
|
||||
conn.apply_timeout(self.client, cmd)
|
||||
ops = islice(self.ops, self.idx_offset, None)
|
||||
namespaces = islice(self.namespaces, self.idx_offset, None)
|
||||
|
||||
# Run as many ops as possible in one server command.
|
||||
if write_concern.acknowledged:
|
||||
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
|
||||
result = raw_result
|
||||
|
||||
# Top-level server/network error.
|
||||
if result.get("error"):
|
||||
error = result["error"]
|
||||
retryable_top_level_error = (
|
||||
hasattr(error, "details")
|
||||
and isinstance(error.details, dict)
|
||||
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
|
||||
)
|
||||
retryable_network_error = isinstance(
|
||||
error, ConnectionFailure
|
||||
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
|
||||
|
||||
# Synthesize the full bulk result without modifying the
|
||||
# current one because this write operation may be retried.
|
||||
if retryable and (retryable_top_level_error or retryable_network_error):
|
||||
full = copy.deepcopy(full_result)
|
||||
_merge_command(self.ops, self.idx_offset, full, result)
|
||||
_throw_client_bulk_write_exception(full, self.verbose_results)
|
||||
else:
|
||||
_merge_command(self.ops, self.idx_offset, full_result, result)
|
||||
_throw_client_bulk_write_exception(full_result, self.verbose_results)
|
||||
|
||||
result["error"] = None
|
||||
result["writeErrors"] = []
|
||||
if result.get("nErrors", 0) < len(to_send_ops):
|
||||
full_result["anySuccessful"] = True
|
||||
|
||||
# Top-level command error.
|
||||
if not result["ok"]:
|
||||
result["error"] = raw_result
|
||||
_merge_command(self.ops, self.idx_offset, full_result, result)
|
||||
break
|
||||
|
||||
if retryable:
|
||||
# Retryable writeConcernErrors halt the execution of this batch.
|
||||
wce = result.get("writeConcernError", {})
|
||||
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
|
||||
# Synthesize the full bulk result without modifying the
|
||||
# current one because this write operation may be retried.
|
||||
full = copy.deepcopy(full_result)
|
||||
_merge_command(self.ops, self.idx_offset, full, result)
|
||||
_throw_client_bulk_write_exception(full, self.verbose_results)
|
||||
|
||||
# Process the server reply as a command cursor.
|
||||
await self._process_results_cursor(full_result, result, conn, session)
|
||||
|
||||
# Merge this batch's results with the full results.
|
||||
_merge_command(self.ops, self.idx_offset, full_result, result)
|
||||
|
||||
# We're no longer in a retry once a command succeeds.
|
||||
self.retrying = False
|
||||
self.started_retryable_write = False
|
||||
|
||||
else:
|
||||
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
|
||||
|
||||
self.idx_offset += len(to_send_ops)
|
||||
|
||||
# We halt execution if we hit a top-level error,
|
||||
# or an individual error in an ordered bulk write.
|
||||
if full_result["error"] or (self.ordered and full_result["writeErrors"]):
|
||||
break
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Execute commands with w=1 WriteConcern."""
|
||||
full_result: MutableMapping[str, Any] = {
|
||||
"anySuccessful": False,
|
||||
"error": None,
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [],
|
||||
"nInserted": 0,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 0,
|
||||
"nModified": 0,
|
||||
"nDeleted": 0,
|
||||
"insertResults": {},
|
||||
"updateResults": {},
|
||||
"deleteResults": {},
|
||||
}
|
||||
op_id = _randint()
|
||||
|
||||
async def retryable_bulk(
|
||||
session: Optional[AsyncClientSession],
|
||||
conn: AsyncConnection,
|
||||
retryable: bool,
|
||||
) -> None:
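# maxWireVersion 25 corresponds to MongoDB 8.0, the first server release
# that implements the bulkWrite command.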
|
||||
if conn.max_wire_version < 25:
|
||||
raise InvalidOperation(
|
||||
"MongoClient.bulk_write requires MongoDB server version 8.0+."
|
||||
)
|
||||
await self._execute_command(
|
||||
self.write_concern,
|
||||
session,
|
||||
conn,
|
||||
op_id,
|
||||
retryable,
|
||||
full_result,
|
||||
)
|
||||
|
||||
await self.client._retryable_write(
|
||||
self.is_retryable,
|
||||
retryable_bulk,
|
||||
session,
|
||||
operation,
|
||||
bulk=self,
|
||||
operation_id=op_id,
|
||||
)
|
||||
|
||||
if full_result["error"] or full_result["writeErrors"] or full_result["writeConcernErrors"]:
|
||||
_throw_client_bulk_write_exception(full_result, self.verbose_results)
|
||||
return full_result
|
||||
|
||||
async def execute_command_unack_unordered(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
) -> None:
|
||||
"""Execute commands with OP_MSG and w=0 writeConcern, unordered."""
|
||||
db_name = "admin"
|
||||
cmd_name = "bulkWrite"
|
||||
listeners = self.client._event_listeners
|
||||
op_id = _randint()
|
||||
|
||||
bwc = self.bulk_ctx_class(
|
||||
db_name,
|
||||
cmd_name,
|
||||
conn,
|
||||
op_id,
|
||||
listeners, # type: ignore[arg-type]
|
||||
None,
|
||||
self.client.codec_options,
|
||||
)
|
||||
|
||||
while self.idx_offset < self.total_ops:
|
||||
# Construct the server command, specifying the relevant options.
|
||||
cmd = {"bulkWrite": 1}
|
||||
cmd["errorsOnly"] = not self.verbose_results
|
||||
cmd["ordered"] = self.ordered # type: ignore[assignment]
|
||||
if self.bypass_doc_val is not None:
|
||||
cmd["bypassDocumentValidation"] = self.bypass_doc_val
|
||||
cmd["writeConcern"] = {"w": 0} # type: ignore[assignment]
|
||||
if self.comment:
|
||||
cmd["comment"] = self.comment # type: ignore[assignment]
|
||||
if self.let:
|
||||
cmd["let"] = self.let
|
||||
|
||||
conn.add_server_api(cmd)
|
||||
ops = islice(self.ops, self.idx_offset, None)
|
||||
namespaces = islice(self.namespaces, self.idx_offset, None)
|
||||
|
||||
# Run as many ops as possible in one server command.
|
||||
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
|
||||
|
||||
self.idx_offset += len(to_send_ops)
|
||||
|
||||
async def execute_command_unack_ordered(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
) -> None:
|
||||
"""Execute commands with OP_MSG and w=0 WriteConcern, ordered."""
|
||||
full_result: MutableMapping[str, Any] = {
|
||||
"anySuccessful": False,
|
||||
"error": None,
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [],
|
||||
"nInserted": 0,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 0,
|
||||
"nModified": 0,
|
||||
"nDeleted": 0,
|
||||
"insertResults": {},
|
||||
"updateResults": {},
|
||||
"deleteResults": {},
|
||||
}
|
||||
# Ordered bulk writes have to be acknowledged so that we stop
|
||||
# processing at the first error, even when the application
|
||||
# specified unacknowledged writeConcern.
|
||||
initial_write_concern = WriteConcern()
|
||||
op_id = _randint()
|
||||
try:
|
||||
await self._execute_command(
|
||||
initial_write_concern,
|
||||
None,
|
||||
conn,
|
||||
op_id,
|
||||
False,
|
||||
full_result,
|
||||
self.write_concern,
|
||||
)
|
||||
except OperationFailure:
|
||||
pass
|
||||
|
||||
async def execute_no_results(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
) -> None:
|
||||
"""Execute all operations, returning no results (w=0)."""
|
||||
if self.uses_collation:
|
||||
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
|
||||
if self.uses_array_filters:
|
||||
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
|
||||
# Cannot have both unacknowledged writes and bypass document validation.
|
||||
if self.bypass_doc_val is not None:
|
||||
raise OperationFailure(
|
||||
"Cannot set bypass_document_validation with unacknowledged write concern"
|
||||
)
|
||||
|
||||
if self.ordered:
|
||||
return await self.execute_command_unack_ordered(conn)
|
||||
return await self.execute_command_unack_unordered(conn)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
session: Optional[AsyncClientSession],
|
||||
operation: str,
|
||||
) -> Any:
|
||||
"""Execute operations."""
|
||||
if not self.ops:
|
||||
raise InvalidOperation("No operations to execute")
|
||||
if self.executed:
|
||||
raise InvalidOperation("Bulk operations can only be executed once.")
|
||||
self.executed = True
|
||||
session = _validate_session_write_concern(session, self.write_concern)
|
||||
|
||||
if not self.write_concern.acknowledged:
|
||||
async with await self.client._conn_for_writes(session, operation) as connection:
|
||||
if connection.max_wire_version < 25:
|
||||
raise InvalidOperation(
|
||||
"MongoClient.bulk_write requires MongoDB server version 8.0+."
|
||||
)
|
||||
await self.execute_no_results(connection)
|
||||
return ClientBulkWriteResult(None, False, False) # type: ignore[arg-type]
|
||||
|
||||
result = await self.execute_command(session, operation)
|
||||
return ClientBulkWriteResult(
|
||||
result,
|
||||
self.write_concern.acknowledged,
|
||||
self.verbose_results,
|
||||
)
|
||||
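# A hedged usage sketch (not part of this module) of the client-level bulk API
# that _AsyncClientBulk backs: AsyncMongoClient.bulk_write (PyMongo 4.9+,
# MongoDB 8.0+, matching the wire-version check above). The namespace keyword
# on the operation models, the result attributes, and the collection names are
# assumptions based on the public PyMongo 4.9 client bulk write API.
from pymongo import AsyncMongoClient, DeleteOne, InsertOne

async def client_bulk_example() -> None:
    client = AsyncMongoClient()
    models = [
        InsertOne(document={"qty": 1}, namespace="shop.orders"),
        DeleteOne(filter={"abandoned": True}, namespace="shop.carts"),
    ]
    result = await client.bulk_write(models, ordered=True)
    print(result.inserted_count, result.deleted_count)
    await client.close()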
venv/lib/python3.11/site-packages/pymongo/asynchronous/collection.py (3563 lines, new file): diff suppressed because it is too large
@@ -0,0 +1,470 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CommandCursor class to iterate over command results."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Generic,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import CodecOptions, _convert_raw_document_lists_to_streams
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
_OpMsg,
|
||||
_OpReply,
|
||||
_RawBatchGetMore,
|
||||
)
|
||||
from pymongo.response import PinnedResponse
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
"""An asynchronous cursor / iterator over command cursors."""
|
||||
|
||||
_getmore_class = _GetMore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: AsyncCollection[_DocumentType],
|
||||
cursor_info: Mapping[str, Any],
|
||||
address: Optional[_Address],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
explicit_session: bool = False,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
"""Create a new command cursor."""
|
||||
self._sock_mgr: Any = None
|
||||
self._collection: AsyncCollection[_DocumentType] = collection
|
||||
self._id = cursor_info["id"]
|
||||
self._data = deque(cursor_info["firstBatch"])
|
||||
self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
|
||||
"postBatchResumeToken"
|
||||
)
|
||||
self._address = address
|
||||
self._batch_size = batch_size
|
||||
self._max_await_time_ms = max_await_time_ms
|
||||
self._timeout = self._collection.database.client.options.timeout
|
||||
self._session = session
|
||||
self._explicit_session = explicit_session
|
||||
self._killed = self._id == 0
|
||||
self._comment = comment
|
||||
if self._killed:
|
||||
self._end_session()
|
||||
|
||||
if "ns" in cursor_info: # noqa: SIM401
|
||||
self._ns = cursor_info["ns"]
|
||||
else:
|
||||
self._ns = collection.full_name
|
||||
|
||||
self.batch_size(batch_size)
|
||||
|
||||
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
|
||||
def __del__(self) -> None:
|
||||
self._die_no_lock()
|
||||
|
||||
def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]:
|
||||
"""Limits the number of documents returned in one batch. Each batch
|
||||
requires a round trip to the server. It can be adjusted to optimize
|
||||
performance and limit data transfer.
|
||||
|
||||
.. note:: batch_size can not override MongoDB's internal limits on the
|
||||
amount of data it will return to the client in a single batch (i.e.
|
||||
if you set batch size to 1,000,000,000, MongoDB will currently only
|
||||
return 4-16MB of results per batch).
|
||||
|
||||
Raises :exc:`TypeError` if `batch_size` is not an integer.
|
||||
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
|
||||
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
|
||||
self._batch_size = batch_size == 1 and 2 or batch_size
|
||||
return self
|
||||
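# A hedged usage sketch (not part of this module): tuning the getMore batch
# size on a cursor returned by aggregate(); "orders" is a placeholder
# AsyncCollection.
cursor = await orders.aggregate([{"$match": {"status": "open"}}])
cursor.batch_size(500)  # each getMore now requests at most 500 documents
async for doc in cursor:
    ...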
|
||||
def _has_next(self) -> bool:
|
||||
"""Returns `True` if the cursor has documents remaining from the
|
||||
previous batch.
|
||||
"""
|
||||
return len(self._data) > 0
|
||||
|
||||
@property
|
||||
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
|
||||
"""Retrieve the postBatchResumeToken from the response to a
|
||||
changeStream aggregate or getMore.
|
||||
"""
|
||||
return self._postbatchresumetoken
|
||||
|
||||
async def _maybe_pin_connection(self, conn: AsyncConnection) -> None:
|
||||
client = self._collection.database.client
|
||||
if not client._should_pin_cursor(self._session):
|
||||
return
|
||||
if not self._sock_mgr:
|
||||
conn.pin_cursor()
|
||||
conn_mgr = _ConnectionManager(conn, False)
|
||||
# Ensure the connection gets returned when the entire result is
|
||||
# returned in the first batch.
|
||||
if self._id == 0:
|
||||
await conn_mgr.close()
|
||||
else:
|
||||
self._sock_mgr = conn_mgr
|
||||
|
||||
def _unpack_response(
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions[Mapping[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> Sequence[_DocumentOut]:
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
"""Does this cursor have the potential to return more data?
|
||||
|
||||
Even if :attr:`alive` is ``True``, :meth:`next` can raise
|
||||
:exc:`StopAsyncIteration`. Best to use an async for loop::
|
||||
|
||||
async for doc in collection.aggregate(pipeline):
|
||||
print(doc)
|
||||
|
||||
.. note:: :attr:`alive` can be True while iterating a cursor from
|
||||
a failed server. In this case :attr:`alive` will return False after
|
||||
:meth:`next` fails to retrieve the next batch of results from the
|
||||
server.
|
||||
"""
|
||||
return bool(len(self._data) or (not self._killed))
|
||||
|
||||
@property
|
||||
def cursor_id(self) -> int:
|
||||
"""Returns the id of the cursor."""
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def address(self) -> Optional[_Address]:
|
||||
"""The (host, port) of the server used, or None.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return self._address
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[AsyncClientSession]:
|
||||
"""The cursor's :class:`~pymongo.asynchronous.client_session.AsyncClientSession`, or None.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
if self._explicit_session:
|
||||
return self._session
|
||||
return None
|
||||
|
||||
def _prepare_to_die(self) -> tuple[int, Optional[_CursorAddress]]:
|
||||
already_killed = self._killed
|
||||
self._killed = True
|
||||
if self._id and not already_killed:
|
||||
cursor_id = self._id
|
||||
assert self._address is not None
|
||||
address = _CursorAddress(self._address, self._ns)
|
||||
else:
|
||||
# Skip killCursors.
|
||||
cursor_id = 0
|
||||
address = None
|
||||
return cursor_id, address
|
||||
|
||||
def _die_no_lock(self) -> None:
|
||||
"""Closes this cursor without acquiring a lock."""
|
||||
cursor_id, address = self._prepare_to_die()
|
||||
self._collection.database.client._cleanup_cursor_no_lock(
|
||||
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
|
||||
)
|
||||
if not self._explicit_session:
|
||||
self._session = None
|
||||
self._sock_mgr = None
|
||||
|
||||
async def _die_lock(self) -> None:
|
||||
"""Closes this cursor."""
|
||||
cursor_id, address = self._prepare_to_die()
|
||||
await self._collection.database.client._cleanup_cursor_lock(
|
||||
cursor_id,
|
||||
address,
|
||||
self._sock_mgr,
|
||||
self._session,
|
||||
self._explicit_session,
|
||||
)
|
||||
if not self._explicit_session:
|
||||
self._session = None
|
||||
self._sock_mgr = None
|
||||
|
||||
def _end_session(self) -> None:
|
||||
if self._session and not self._explicit_session:
|
||||
self._session._end_implicit_session()
|
||||
self._session = None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Explicitly close / kill this cursor."""
|
||||
await self._die_lock()
|
||||
|
||||
async def _send_message(self, operation: _GetMore) -> None:
|
||||
"""Send a getmore message and handle the response."""
|
||||
client = self._collection.database.client
|
||||
try:
|
||||
response = await client._run_operation(
|
||||
operation, self._unpack_response, address=self._address
|
||||
)
|
||||
except OperationFailure as exc:
|
||||
if exc.code in _CURSOR_CLOSED_ERRORS:
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self._killed = True
|
||||
if exc.timeout:
|
||||
self._die_no_lock()
|
||||
else:
|
||||
# Return the session and pinned connection, if necessary.
|
||||
await self.close()
|
||||
raise
|
||||
except ConnectionFailure:
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self._killed = True
|
||||
# Return the session and pinned connection, if necessary.
|
||||
await self.close()
|
||||
raise
|
||||
except Exception:
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self._sock_mgr:
|
||||
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
|
||||
if response.from_command:
|
||||
cursor = response.docs[0]["cursor"]
|
||||
documents = cursor["nextBatch"]
|
||||
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
|
||||
self._id = cursor["id"]
|
||||
else:
|
||||
documents = response.docs
|
||||
assert isinstance(response.data, _OpReply)
|
||||
self._id = response.data.cursor_id
|
||||
|
||||
if self._id == 0:
|
||||
await self.close()
|
||||
self._data = deque(documents)
|
||||
|
||||
async def _refresh(self) -> int:
|
||||
"""Refreshes the cursor with more data from the server.
|
||||
|
||||
Returns the length of self._data after refresh. Will exit early if
|
||||
self._data is already non-empty. Raises OperationFailure when the
|
||||
cursor cannot be refreshed due to an error on the query.
|
||||
"""
|
||||
if len(self._data) or self._killed:
|
||||
return len(self._data)
|
||||
|
||||
if self._id: # Get More
|
||||
dbname, collname = self._ns.split(".", 1)
|
||||
read_pref = self._collection._read_preference_for(self.session)
|
||||
await self._send_message(
|
||||
self._getmore_class(
|
||||
dbname,
|
||||
collname,
|
||||
self._batch_size,
|
||||
self._id,
|
||||
self._collection.codec_options,
|
||||
read_pref,
|
||||
self._session,
|
||||
self._collection.database.client,
|
||||
self._max_await_time_ms,
|
||||
self._sock_mgr,
|
||||
False,
|
||||
self._comment,
|
||||
)
|
||||
)
|
||||
else:  # Cursor id is zero, nothing else to return
|
||||
await self._die_lock()
|
||||
|
||||
return len(self._data)
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[_DocumentType]:
|
||||
return self
|
||||
|
||||
async def next(self) -> _DocumentType:
|
||||
"""Advance the cursor."""
|
||||
# Block until a document is returnable.
|
||||
while self.alive:
|
||||
doc = await self._try_next(True)
|
||||
if doc is not None:
|
||||
return doc
|
||||
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def __anext__(self) -> _DocumentType:
|
||||
return await self.next()
|
||||
|
||||
async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor blocking for at most one getMore command."""
|
||||
if not len(self._data) and not self._killed and get_more_allowed:
|
||||
await self._refresh()
|
||||
if len(self._data):
|
||||
return self._data.popleft()
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
await self._refresh()
|
||||
if len(self._data):
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def try_next(self) -> Optional[_DocumentType]:
|
||||
"""Advance the cursor without blocking indefinitely.
|
||||
|
||||
This method returns the next document without waiting
|
||||
indefinitely for data.
|
||||
|
||||
If no document is cached locally then this method runs a single
|
||||
getMore command. If the getMore yields any documents, the next
|
||||
document is returned, otherwise, if the getMore returns no documents
|
||||
(because there is no additional data) then ``None`` is returned.
|
||||
|
||||
:return: The next document or ``None`` when no document is available
|
||||
after running a single getMore or when the cursor is closed.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
"""
|
||||
return await self._try_next(get_more_allowed=True)
|
||||
|
||||
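# A hedged usage sketch (not part of this module): try_next() lets a loop poll
# for new data without blocking longer than one getMore. "coll" is a
# placeholder, and using a raw $changeStream stage via aggregate() assumes a
# replica set; the usual public API for this is coll.watch().
cursor = await coll.aggregate([{"$changeStream": {}}])
while cursor.alive:
    change = await cursor.try_next()
    if change is None:
        continue  # no new event after a single getMore
    print(change["operationType"])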
async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
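# A hedged usage sketch (not part of this module): the async context manager
# defined above guarantees the cursor is closed; "coll" is a placeholder.
async with await coll.aggregate([{"$sample": {"size": 10}}]) as cursor:
    async for doc in cursor:
        ...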
|
||||
@_csot.apply
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, to read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not await self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
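# A hedged usage sketch (not part of this module): drain a command cursor with
# to_list(); "events" is a placeholder AsyncCollection.
cursor = await events.aggregate([{"$sort": {"ts": -1}}])
first_page = await cursor.to_list(50)  # at most 50 documents
rest = await cursor.to_list()          # everything that remains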
|
||||
|
||||
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
|
||||
_getmore_class = _RawBatchGetMore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: AsyncCollection[_DocumentType],
|
||||
cursor_info: Mapping[str, Any],
|
||||
address: Optional[_Address],
|
||||
batch_size: int = 0,
|
||||
max_await_time_ms: Optional[int] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
explicit_session: bool = False,
|
||||
comment: Any = None,
|
||||
) -> None:
|
||||
"""Create a new cursor / iterator over raw batches of BSON data.
|
||||
|
||||
Should not be called directly by application developers -
|
||||
see :meth:`~pymongo.asynchronous.collection.AsyncCollection.aggregate_raw_batches`
|
||||
instead.
|
||||
|
||||
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
|
||||
"""
|
||||
assert not cursor_info.get("firstBatch")
|
||||
super().__init__(
|
||||
collection,
|
||||
cursor_info,
|
||||
address,
|
||||
batch_size,
|
||||
max_await_time_ms,
|
||||
session,
|
||||
explicit_session,
|
||||
comment,
|
||||
)
|
||||
|
||||
def _unpack_response( # type: ignore[override]
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
|
||||
if not legacy_response:
|
||||
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
|
||||
# Re-assemble the array of documents into a document stream
|
||||
_convert_raw_document_lists_to_streams(raw_response[0])
|
||||
return raw_response # type: ignore[return-value]
|
||||
|
||||
def __getitem__(self, index: int) -> NoReturn:
|
||||
raise InvalidOperation("Cannot call __getitem__ on AsyncRawBatchCommandCursor")
|
||||
venv/lib/python3.11/site-packages/pymongo/asynchronous/cursor.py (1367 lines, new file): diff suppressed because it is too large
venv/lib/python3.11/site-packages/pymongo/asynchronous/database.py (1432 lines, new file): diff suppressed because it is too large
venv/lib/python3.11/site-packages/pymongo/asynchronous/encryption.py (1169 lines, new file): diff suppressed because it is too large
@@ -0,0 +1,82 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Miscellaneous pieces that need to be synchronized."""
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo.errors import (
|
||||
OperationFailure,
|
||||
)
|
||||
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _handle_reauth(func: F) -> F:
|
||||
async def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
no_reauth = kwargs.pop("no_reauth", False)
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.message import _BulkWriteContext
|
||||
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except OperationFailure as exc:
|
||||
if no_reauth:
|
||||
raise
|
||||
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
|
||||
# Look for an argument that either is an AsyncConnection
|
||||
# or has a connection attribute, so we can trigger
|
||||
# a reauth.
|
||||
conn = None
|
||||
for arg in args:
|
||||
if isinstance(arg, AsyncConnection):
|
||||
conn = arg
|
||||
break
|
||||
if isinstance(arg, _BulkWriteContext):
|
||||
conn = arg.conn # type: ignore[assignment]
|
||||
break
|
||||
if conn:
|
||||
await conn.authenticate(reauthenticate=True)
|
||||
else:
|
||||
raise
|
||||
return await func(*args, **kwargs)
|
||||
raise
|
||||
|
||||
return cast(F, inner)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
anext = builtins.anext
|
||||
aiter = builtins.aiter
|
||||
else:
|
||||
|
||||
async def anext(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
|
||||
return await cls.__anext__()
|
||||
|
||||
def aiter(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
|
||||
return cls.__aiter__()
|
||||
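# A hedged usage sketch (not part of this module): driver code can call
# anext()/aiter() uniformly; on Python 3.10+ these resolve to the builtins.
cursor = await coll.aggregate([{"$limit": 1}])  # "coll" is a placeholder
first_doc = await anext(aiter(cursor))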
File diff suppressed because it is too large
@@ -0,0 +1,534 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Class to monitor a MongoDB server on a background thread."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _sanitize(error: Exception) -> None:
|
||||
"""PYTHON-2433 Clear error traceback info."""
|
||||
error.__traceback__ = None
|
||||
error.__context__ = None
|
||||
error.__cause__ = None
|
||||
|
||||
|
||||
def _monotonic_duration(start: float) -> float:
|
||||
"""Return the duration since the given start time.
|
||||
|
||||
Accounts for buggy platforms where time.monotonic() is not monotonic.
|
||||
See PYTHON-4600.
|
||||
"""
|
||||
return max(0.0, time.monotonic() - start)
|
||||
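# A minimal sketch of how the helper above is used for heartbeat timing.
start = time.monotonic()
# ... perform the hello round trip ...
duration = _monotonic_duration(start)  # clamped to >= 0.0 on buggy clocks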
|
||||
|
||||
class MonitorBase:
|
||||
def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
|
||||
"""Base class to do periodic work on a background thread.
|
||||
|
||||
The background thread is signaled to stop when the Topology or
|
||||
this instance is freed.
|
||||
"""
|
||||
|
||||
# We strongly reference the executor and it weakly references us via
|
||||
# this closure. When the monitor is freed, stop the executor soon.
|
||||
async def target() -> bool:
|
||||
monitor = self_ref()
|
||||
if monitor is None:
|
||||
return False # Stop the executor.
|
||||
await monitor._run() # type:ignore[attr-defined]
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
interval=interval, min_interval=min_interval, target=target, name=name
|
||||
)
|
||||
|
||||
self._executor = executor
|
||||
|
||||
def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
|
||||
# This prevents GC from waiting 10 seconds for hello to complete
|
||||
# See test_cleanup_executors_on_client_del.
|
||||
monitor = self_ref()
|
||||
if monitor:
|
||||
monitor.gc_safe_close()
|
||||
|
||||
# Avoid cycles. When self or topology is freed, stop executor soon.
|
||||
self_ref = weakref.ref(self, executor.close)
|
||||
self._topology = weakref.proxy(topology, _on_topology_gc)
|
||||
_register(self)
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
|
||||
self._executor.open()
|
||||
|
||||
def gc_safe_close(self) -> None:
|
||||
"""GC safe close."""
|
||||
self._executor.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close and stop monitoring.
|
||||
|
||||
open() restarts the monitor after closing.
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
self._executor.wake()
|
||||
|
||||
|
||||
class Monitor(MonitorBase):
|
||||
def __init__(
|
||||
self,
|
||||
server_description: ServerDescription,
|
||||
topology: Topology,
|
||||
pool: Pool,
|
||||
topology_settings: TopologySettings,
|
||||
):
|
||||
"""Class to monitor a MongoDB server on a background thread.
|
||||
|
||||
Pass an initial ServerDescription, a Topology, a Pool, and
|
||||
TopologySettings.
|
||||
|
||||
The Topology is weakly referenced. The Pool must be exclusive to this
|
||||
Monitor.
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_monitor_thread",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
self._server_description = server_description
|
||||
self._pool = pool
|
||||
self._settings = topology_settings
|
||||
self._listeners = self._settings._pool_options._event_listeners
|
||||
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
|
||||
self._cancel_context: Optional[_CancellationContext] = None
|
||||
self._rtt_monitor = _RttMonitor(
|
||||
topology,
|
||||
topology_settings,
|
||||
topology._create_pool_for_monitor(server_description.address),
|
||||
)
|
||||
if topology_settings.server_monitoring_mode == "stream":
|
||||
self._stream = True
|
||||
elif topology_settings.server_monitoring_mode == "poll":
|
||||
self._stream = False
|
||||
else:
|
||||
self._stream = not _is_faas()
|
||||
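# A hedged sketch (not part of this module), assuming the standard
# serverMonitoringMode client option: "stream" forces streaming hello checks,
# "poll" forces polling, and the default "auto" polls only in FaaS
# environments, mirroring the branch above.
from pymongo import AsyncMongoClient

client = AsyncMongoClient("mongodb://localhost:27017/?serverMonitoringMode=poll")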
|
||||
def cancel_check(self) -> None:
|
||||
"""Cancel any concurrent hello check.
|
||||
|
||||
Note: this is called from a weakref.proxy callback and MUST NOT take
|
||||
any locks.
|
||||
"""
|
||||
context = self._cancel_context
|
||||
if context:
|
||||
# Note: we cannot close the socket because doing so may cause
|
||||
# concurrent reads/writes to hang until a timeout occurs
|
||||
# (depending on the platform).
|
||||
context.cancel()
|
||||
|
||||
async def _start_rtt_monitor(self) -> None:
|
||||
"""Start an _RttMonitor that periodically runs ping."""
|
||||
# If this monitor is closed directly before (or during) this open()
|
||||
# call, the _RttMonitor will not be closed. Checking if this monitor
|
||||
# was closed directly after resolves the race.
|
||||
self._rtt_monitor.open()
|
||||
if self._executor._stopped:
|
||||
await self._rtt_monitor.close()
|
||||
|
||||
def gc_safe_close(self) -> None:
|
||||
self._executor.close()
|
||||
self._rtt_monitor.gc_safe_close()
|
||||
self.cancel_check()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
await self._rtt_monitor.close()
|
||||
# Increment the generation and maybe close the socket. If the executor
|
||||
# thread has the socket checked out, it will be closed when checked in.
|
||||
await self._reset_connection()
|
||||
|
||||
async def _reset_connection(self) -> None:
|
||||
# Clear our pooled connection.
|
||||
await self._pool.reset()
|
||||
|
||||
async def _run(self) -> None:
|
||||
try:
|
||||
prev_sd = self._server_description
|
||||
try:
|
||||
self._server_description = await self._check_server()
|
||||
except _OperationCancelled as exc:
|
||||
_sanitize(exc)
|
||||
# Already closed the connection, wait for the next check.
|
||||
self._server_description = ServerDescription(
|
||||
self._server_description.address, error=exc
|
||||
)
|
||||
if prev_sd.is_server_type_known:
|
||||
# Immediately retry since we've already waited 500ms to
|
||||
# discover that we've been cancelled.
|
||||
self._executor.skip_sleep()
|
||||
return
|
||||
|
||||
# Update the Topology and clear the server pool on error.
|
||||
await self._topology.on_change(
|
||||
self._server_description,
|
||||
reset_pool=self._server_description.error,
|
||||
interrupt_connections=isinstance(self._server_description.error, NetworkTimeout),
|
||||
)
|
||||
|
||||
if self._stream and (
|
||||
self._server_description.is_server_type_known
|
||||
and self._server_description.topology_version
|
||||
):
|
||||
await self._start_rtt_monitor()
|
||||
# Immediately check for the next streaming response.
|
||||
self._executor.skip_sleep()
|
||||
|
||||
if self._server_description.error and prev_sd.is_server_type_known:
|
||||
# Immediately retry on network errors.
|
||||
self._executor.skip_sleep()
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
|
||||
async def _check_server(self) -> ServerDescription:
|
||||
"""Call hello or read the next streaming response.
|
||||
|
||||
Returns a ServerDescription.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
try:
|
||||
try:
|
||||
return await self._check_once()
|
||||
except (OperationFailure, NotPrimaryError) as exc:
|
||||
# Update max cluster time even when hello fails.
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
_sanitize(error)
|
||||
sd = self._server_description
|
||||
address = sd.address
|
||||
duration = _monotonic_duration(start)
|
||||
awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology._topology_id,
|
||||
serverHost=address[0],
|
||||
serverPort=address[1],
|
||||
awaited=awaited,
|
||||
durationMS=duration * 1000,
|
||||
failure=error,
|
||||
message=_SDAMStatusMessage.HEARTBEAT_FAIL,
|
||||
)
|
||||
await self._reset_connection()
|
||||
if isinstance(error, _OperationCancelled):
|
||||
raise
|
||||
self._rtt_monitor.reset()
|
||||
# Server type defaults to Unknown.
|
||||
return ServerDescription(address, error=error)
|
||||
|
||||
async def _check_once(self) -> ServerDescription:
|
||||
"""A single attempt to call hello.
|
||||
|
||||
Returns a ServerDescription, or raises an exception.
|
||||
"""
|
||||
address = self._server_description.address
|
||||
sd = self._server_description
|
||||
|
||||
# XXX: "awaited" could be incorrectly set to True in the rare case
|
||||
# the pool checkout closes and recreates a connection.
|
||||
awaited = bool(
|
||||
self._pool.conns and self._stream and sd.is_server_type_known and sd.topology_version
|
||||
)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_started(address, awaited)
|
||||
|
||||
if self._cancel_context and self._cancel_context.cancelled:
|
||||
await self._reset_connection()
|
||||
async with self._pool.checkout() as conn:
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology._topology_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=address[0],
|
||||
serverPort=address[1],
|
||||
awaited=awaited,
|
||||
message=_SDAMStatusMessage.HEARTBEAT_START,
|
||||
)
|
||||
|
||||
self._cancel_context = conn.cancel_context
|
||||
response, round_trip_time = await self._check_with_socket(conn)
|
||||
if not response.awaitable:
|
||||
self._rtt_monitor.add_sample(round_trip_time)
|
||||
|
||||
avg_rtt, min_rtt = self._rtt_monitor.get()
|
||||
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_succeeded(
|
||||
address, round_trip_time, response, response.awaitable
|
||||
)
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology._topology_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=address[0],
|
||||
serverPort=address[1],
|
||||
awaited=awaited,
|
||||
durationMS=round_trip_time * 1000,
|
||||
reply=response.document,
|
||||
message=_SDAMStatusMessage.HEARTBEAT_SUCCESS,
|
||||
)
|
||||
return sd
|
||||
|
||||
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
|
||||
"""Return (Hello, round_trip_time).
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
"""
|
||||
cluster_time = self._topology.max_cluster_time()
|
||||
start = time.monotonic()
|
||||
if conn.more_to_come:
|
||||
# Read the next streaming hello (MongoDB 4.4+).
|
||||
response = Hello(await conn._next_reply(), awaitable=True)
|
||||
elif (
|
||||
self._stream and conn.performed_handshake and self._server_description.topology_version
|
||||
):
|
||||
# Initiate streaming hello (MongoDB 4.4+).
|
||||
response = await conn._hello(
|
||||
cluster_time,
|
||||
self._server_description.topology_version,
|
||||
self._settings.heartbeat_frequency,
|
||||
)
|
||||
else:
|
||||
# New connection handshake or polling hello (MongoDB <4.4).
|
||||
response = await conn._hello(cluster_time, None, None)
|
||||
duration = _monotonic_duration(start)
|
||||
return response, duration
|
||||
|
||||
|
||||
class SrvMonitor(MonitorBase):
|
||||
def __init__(self, topology: Topology, topology_settings: TopologySettings):
|
||||
"""Class to poll SRV records on a background thread.
|
||||
|
||||
Pass a Topology and a TopologySettings.
|
||||
|
||||
The Topology is weakly referenced.
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_srv_polling_thread",
|
||||
common.MIN_SRV_RESCAN_INTERVAL,
|
||||
topology_settings.heartbeat_frequency,
|
||||
)
|
||||
self._settings = topology_settings
|
||||
self._seedlist = self._settings._seeds
|
||||
assert isinstance(self._settings.fqdn, str)
|
||||
self._fqdn: str = self._settings.fqdn
|
||||
self._startup_time = time.monotonic()
|
||||
|
||||
async def _run(self) -> None:
|
||||
# Don't poll right after creation, wait 60 seconds first
|
||||
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
|
||||
return
|
||||
seedlist = self._get_seedlist()
|
||||
if seedlist:
|
||||
self._seedlist = seedlist
|
||||
try:
|
||||
await self._topology.on_srv_update(self._seedlist)
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
|
||||
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
|
||||
"""Poll SRV records for a seedlist.
|
||||
|
||||
Returns a list of ServerDescriptions.
|
||||
"""
|
||||
try:
|
||||
resolver = _SrvResolver(
|
||||
self._fqdn,
|
||||
self._settings.pool_options.connect_timeout,
|
||||
self._settings.srv_service_name,
|
||||
)
|
||||
seedlist, ttl = resolver.get_hosts_and_min_ttl()
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
# - SRV records must be rescanned every heartbeatFrequencyMS
|
||||
# - Topology must be left unchanged
|
||||
self.request_check()
|
||||
return None
|
||||
else:
|
||||
self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
|
||||
return seedlist
|
||||
|
||||
|
||||
class _RttMonitor(MonitorBase):
|
||||
def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
|
||||
"""Maintain round trip times for a server.
|
||||
|
||||
The Topology is weakly referenced.
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_rtt_thread",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
|
||||
self._pool = pool
|
||||
self._moving_average = MovingAverage()
|
||||
self._moving_min = MovingMinimum()
|
||||
self._lock = _create_lock()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
# Increment the generation and maybe close the socket. If the executor
|
||||
# thread has the socket checked out, it will be closed when checked in.
|
||||
await self._pool.reset()
|
||||
|
||||
def add_sample(self, sample: float) -> None:
|
||||
"""Add a RTT sample."""
|
||||
with self._lock:
|
||||
self._moving_average.add_sample(sample)
|
||||
self._moving_min.add_sample(sample)
|
||||
|
||||
def get(self) -> tuple[Optional[float], float]:
|
||||
"""Get the calculated average, or None if no samples yet and the min."""
|
||||
with self._lock:
|
||||
return self._moving_average.get(), self._moving_min.get()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the average RTT."""
|
||||
with self._lock:
|
||||
self._moving_average.reset()
|
||||
self._moving_min.reset()
|
||||
|
||||
async def _run(self) -> None:
|
||||
try:
|
||||
# NOTE: This thread is only run when using the streaming
|
||||
# heartbeat protocol (MongoDB 4.4+).
|
||||
# XXX: Skip check if the server is unknown?
|
||||
rtt = await self._ping()
|
||||
self.add_sample(rtt)
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
except Exception:
|
||||
await self._pool.reset()
|
||||
|
||||
async def _ping(self) -> float:
|
||||
"""Run a "hello" command and return the RTT."""
|
||||
async with self._pool.checkout() as conn:
|
||||
if self._executor._stopped:
|
||||
raise Exception("_RttMonitor closed")
|
||||
start = time.monotonic()
|
||||
await conn.hello()
|
||||
return _monotonic_duration(start)
|
||||
|
||||
|
||||
# Close monitors to cancel any in progress streaming checks before joining
|
||||
# executor threads. For an explanation of how this works see the comment
|
||||
# about _EXECUTORS in periodic_executor.py.
|
||||
_MONITORS = set()
|
||||
|
||||
|
||||
def _register(monitor: MonitorBase) -> None:
|
||||
ref = weakref.ref(monitor, _unregister)
|
||||
_MONITORS.add(ref)
|
||||
|
||||
|
||||
def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
|
||||
_MONITORS.remove(monitor_ref)
|
||||
|
||||
|
||||
def _shutdown_monitors() -> None:
|
||||
if _MONITORS is None:
|
||||
return
|
||||
|
||||
# Copy the set. Closing monitors removes them.
|
||||
monitors = list(_MONITORS)
|
||||
|
||||
# Close all monitors.
|
||||
for ref in monitors:
|
||||
monitor = ref()
|
||||
if monitor:
|
||||
monitor.gc_safe_close()
|
||||
|
||||
monitor = None
|
||||
|
||||
|
||||
def _shutdown_resources() -> None:
|
||||
# _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
|
||||
shutdown = _shutdown_monitors
|
||||
if shutdown: # type:ignore[truthy-function]
|
||||
shutdown()
|
||||
shutdown = _shutdown_executors
|
||||
if shutdown: # type:ignore[truthy-function]
|
||||
shutdown()
|
||||
|
||||
|
||||
atexit.register(_shutdown_resources)
|
||||
@@ -0,0 +1,414 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal network layer helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo import _csot, helpers_shared, message
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
_OperationCancelled,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_POLL_TIMEOUT,
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
BLOCKING_IO_ERRORS,
|
||||
async_sendall,
|
||||
)
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def command(
|
||||
conn: AsyncConnection,
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
read_preference: Optional[_ServerMode],
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[AsyncClientSession],
|
||||
client: Optional[AsyncMongoClient],
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
address: Optional[_Address] = None,
|
||||
listeners: Optional[_EventListeners] = None,
|
||||
max_bson_size: Optional[int] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
use_op_msg: bool = False,
|
||||
unacknowledged: bool = False,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
exhaust_allowed: bool = False,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
) -> _DocumentType:
|
||||
"""Execute a command over the socket, or raise socket.error.
|
||||
|
||||
:param conn: a AsyncConnection instance
|
||||
:param dbname: name of the database on which to run the command
|
||||
:param spec: a command document as an ordered dict type, eg SON.
|
||||
:param is_mongos: are we connected to a mongos?
|
||||
:param read_preference: a read preference
|
||||
:param codec_options: a CodecOptions instance
|
||||
:param session: optional AsyncClientSession instance.
|
||||
:param client: optional AsyncMongoClient instance for updating $clusterTime.
|
||||
:param check: raise OperationFailure if there are errors
|
||||
:param allowable_errors: errors to ignore if `check` is True
|
||||
:param address: the (host, port) of `conn`
|
||||
:param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
|
||||
:param max_bson_size: The maximum encoded bson size for this server
|
||||
:param read_concern: The read concern for this command.
|
||||
:param parse_write_concern_error: Whether to parse the ``writeConcernError``
|
||||
field in the command response.
|
||||
:param collation: The collation for this command.
|
||||
:param compression_ctx: optional compression Context.
|
||||
:param use_op_msg: True if we should use OP_MSG.
|
||||
:param unacknowledged: True if this is an unacknowledged command.
|
||||
:param user_fields: Response fields that should be decoded
|
||||
using the TypeDecoders from codec_options, passed to
|
||||
bson._decode_all_selective.
|
||||
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
|
||||
"""
|
||||
name = next(iter(spec))
|
||||
ns = dbname + ".$cmd"
|
||||
speculative_hello = False
|
||||
|
||||
# Publish the original command document, perhaps with lsid and $clusterTime.
|
||||
orig = spec
|
||||
if is_mongos and not use_op_msg:
|
||||
assert read_preference is not None
|
||||
spec = message._maybe_add_read_preference(spec, read_preference)
|
||||
if read_concern and not (session and session.in_transaction):
|
||||
if read_concern.level:
|
||||
spec["readConcern"] = read_concern.document
|
||||
if session:
|
||||
session._update_read_concern(spec, conn)
|
||||
if collation is not None:
|
||||
spec["collation"] = collation
|
||||
|
||||
publish = listeners is not None and listeners.enabled_for_commands
|
||||
start = datetime.datetime.now()
|
||||
if publish:
|
||||
speculative_hello = _is_speculative_authenticate(name, spec)
|
||||
|
||||
if compression_ctx and name.lower() in _NO_COMPRESSION:
|
||||
compression_ctx = None
|
||||
|
||||
if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
|
||||
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)
|
||||
|
||||
# Support CSOT
|
||||
if client:
|
||||
conn.apply_timeout(client, spec)
|
||||
_csot.apply_write_concern(spec, write_concern)
|
||||
|
||||
if use_op_msg:
|
||||
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
|
||||
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
|
||||
request_id, msg, size, max_doc_size = message._op_msg(
|
||||
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
|
||||
)
|
||||
# If this is an unacknowledged write then make sure the encoded doc(s)
|
||||
# are small enough, otherwise rely on the server to return an error.
|
||||
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
|
||||
message._raise_document_too_large(name, size, max_bson_size)
|
||||
else:
|
||||
request_id, msg, size = message._query(
|
||||
0, ns, 0, -1, spec, None, codec_options, compression_ctx
|
||||
)
|
||||
|
||||
if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
|
||||
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
|
||||
if client is not None:
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=spec,
|
||||
commandName=next(iter(spec)),
|
||||
databaseName=dbname,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
)
|
||||
if publish:
|
||||
assert listeners is not None
|
||||
assert address is not None
|
||||
listeners.publish_command_start(
|
||||
orig,
|
||||
dbname,
|
||||
request_id,
|
||||
address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await async_sendall(conn.conn, msg)
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
response_doc: _DocumentOut = {"ok": 1}
|
||||
else:
|
||||
reply = await receive_message(conn, request_id)
|
||||
conn.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response(
|
||||
codec_options=codec_options, user_fields=user_fields
|
||||
)
|
||||
|
||||
response_doc = unpacked_docs[0]
|
||||
if client:
|
||||
await client._process_response(response_doc, session)
|
||||
if check:
|
||||
helpers_shared._check_command_response(
|
||||
response_doc,
|
||||
conn.max_wire_version,
|
||||
allowable_errors,
|
||||
parse_write_concern_error=parse_write_concern_error,
|
||||
)
|
||||
except Exception as exc:
|
||||
duration = datetime.datetime.now() - start
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = message._convert_exception(exc)
|
||||
if client is not None:
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(spec)),
|
||||
databaseName=dbname,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if publish:
|
||||
assert listeners is not None
|
||||
assert address is not None
|
||||
listeners.publish_command_failure(
|
||||
duration,
|
||||
failure,
|
||||
name,
|
||||
request_id,
|
||||
address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
database_name=dbname,
|
||||
)
|
||||
raise
|
||||
duration = datetime.datetime.now() - start
|
||||
if client is not None:
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=response_doc,
|
||||
commandName=next(iter(spec)),
|
||||
databaseName=dbname,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
speculative_authenticate="speculativeAuthenticate" in orig,
|
||||
)
|
||||
if publish:
|
||||
assert listeners is not None
|
||||
assert address is not None
|
||||
listeners.publish_command_success(
|
||||
duration,
|
||||
response_doc,
|
||||
name,
|
||||
request_id,
|
||||
address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
speculative_hello=speculative_hello,
|
||||
database_name=dbname,
|
||||
)
|
||||
|
||||
if client and client._encrypter and reply:
|
||||
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
|
||||
response_doc = cast(
|
||||
"_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
|
||||
)
|
||||
|
||||
return response_doc # type: ignore[return-value]
|
||||
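# A hedged sketch (not part of this module): application code reaches the
# command() helper above indirectly through the public command API, e.g.:
from pymongo import AsyncMongoClient

async def ping_example() -> None:
    client = AsyncMongoClient()
    reply = await client.admin.command("ping")  # {'ok': 1.0, ...}
    print(reply)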
|
||||
|
||||
async def receive_message(
    conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error."""
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        timeout = conn.conn.gettimeout()
        if timeout:
            deadline = time.monotonic() + timeout
        else:
            deadline = None
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(
        await _receive_data_on_socket(conn, 16, deadline)
    )
    # No request_id for exhaust cursor "getMore".
    if request_id is not None:
        if request_id != response_to:
            raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )
    if op_code == 2012:
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
            await _receive_data_on_socket(conn, 9, deadline)
        )
        data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
    else:
        data = await _receive_data_on_socket(conn, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    return unpack_reply(data)
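
# --- Illustrative sketch (not part of the diff): layout of the 16-byte wire-protocol
# --- header consumed by _UNPACK_HEADER above. MongoDB frames every message with four
# --- little-endian int32s (messageLength, requestID, responseTo, opCode); op code 2012
# --- (OP_COMPRESSED) adds the 9-byte compression sub-header handled above. The helper
# --- name and the local struct import are illustration-only assumptions.
import struct as _struct_sketch


def _parse_header_sketch(header: bytes) -> tuple[int, int, int, int]:
    """Return (message_length, request_id, response_to, op_code) from a 16-byte header."""
    return _struct_sketch.unpack("<iiii", header)


# Example: _parse_header_sketch(_struct_sketch.pack("<iiii", 16, 1, 0, 2013)) == (16, 1, 0, 2013)
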
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel."""
    sock = conn.conn
    timed_out = False
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait up to 500ms for the socket to become readable and then
            # check for cancellation.
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if remaining <= 0:
                    timed_out = True
                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                timeout = _POLL_TIMEOUT
            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if timed_out:
            raise socket.timeout("timed out")
        await asyncio.sleep(0)


async def _receive_data_on_socket(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    buf = bytearray(length)
    mv = memoryview(buf)
    bytes_read = 0
    while bytes_read < length:
        try:
            await wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            chunk_length = conn.conn.recv_into(mv[bytes_read:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as exc:
            if _errno_from_exception(exc) == errno.EINTR:
                continue
            raise
        if chunk_length == 0:
            raise OSError("connection closed")

        bytes_read += chunk_length

    return mv
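
# --- Illustrative sketch (not part of the diff): the preallocated-buffer read loop used by
# --- _receive_data_on_socket above, shown standalone on a plain blocking socket. The helper
# --- name is invented for illustration; `socket` is already imported by this module.
def _read_exactly_sketch(sock: socket.socket, length: int) -> memoryview:
    """Read exactly `length` bytes with recv_into, filling one buffer slice by slice."""
    buf = bytearray(length)
    mv = memoryview(buf)
    read = 0
    while read < length:
        n = sock.recv_into(mv[read:])
        if n == 0:
            raise OSError("connection closed")
        read += n
    return mv


# Example: a, b = socket.socketpair(); a.sendall(b"hello"); bytes(_read_exactly_sketch(b, 5)) == b"hello"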
@@ -0,0 +1,219 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Run a target function on a background thread."""

from __future__ import annotations

import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional

from pymongo.lock import _ALock, _create_lock

_IS_SYNC = False


class PeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background thread.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying thread.
|
||||
"""
|
||||
# threading.Event and its internal condition variable are expensive
|
||||
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
|
||||
# The executor's design is constrained by several Python issues, see
|
||||
# "periodic_executor.rst" in this repository.
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
self._thread_will_exit = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
# The default asyncio loop implementation on Windows
|
||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
||||
# We explicitly use a different loop implementation here to prevent that issue
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.SelectorEventLoop()
|
||||
try:
|
||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
Not safe to call from multiple threads at once.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._thread_will_exit:
|
||||
# If the background thread has read self._stopped as True
|
||||
# there is a chance that it has not yet exited. The call to
|
||||
# join should not block indefinitely because there is no
|
||||
# other work done outside the while loop in self._run.
|
||||
try:
|
||||
assert self._thread is not None
|
||||
self._thread.join()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
self._thread_will_exit = False
|
||||
self._stopped = False
|
||||
started: Any = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
# Mitigation to RuntimeError firing when thread starts on shutdown
|
||||
# https://github.com/python/cpython/issues/114570
|
||||
try:
|
||||
thread.start()
|
||||
except RuntimeError as e:
|
||||
if "interpreter shutdown" in str(e) or sys.is_finalizing():
|
||||
self._thread = None
|
||||
return
|
||||
raise
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._thread is not None:
|
||||
try:
|
||||
self._thread.join(timeout)
|
||||
except (ReferenceError, RuntimeError):
|
||||
# Thread already terminated, or not yet started.
|
||||
pass
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
    async def _should_stop(self) -> bool:
        async with self._lock:
            if self._stopped:
                self._thread_will_exit = True
                return True
            return False

    async def _run(self) -> None:
        while not await self._should_stop():
            try:
                if not await self._target():
                    self._stopped = True
                    break
            except BaseException:
                async with self._lock:
                    self._stopped = True
                    self._thread_will_exit = True

                raise

            if self._skip_sleep:
                self._skip_sleep = False
            else:
                deadline = time.monotonic() + self._interval
                while not self._stopped and time.monotonic() < deadline:
                    await asyncio.sleep(self._min_interval)
                    if self._event:
                        break  # Early wake.

            self._event = False

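
# --- Illustrative sketch (not part of the diff): how a monitor-style caller might drive
# --- the executor above. The coroutine and interval values are invented for illustration;
# --- in this async variant (_IS_SYNC = False) the target must be awaitable, and a falsey
# --- return value stops the loop. Kept as a comment because open() starts a daemon thread.
#
#   async def _ping_once() -> bool:
#       print("heartbeat")
#       return True  # return False (or None) to stop the executor
#
#   executor = PeriodicExecutor(interval=10, min_interval=0.5, target=_ping_once, name="heartbeat")
#   executor.open()    # starts a daemon thread that runs _run() in an asyncio event loop
#   executor.wake()    # request an extra run soon
#   executor.close()   # signal stop; join() then waits for the thread to exit
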
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()


def _register_executor(executor: PeriodicExecutor) -> None:
    ref = weakref.ref(executor, _on_executor_deleted)
    _EXECUTORS.add(ref)


def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
    _EXECUTORS.remove(ref)


def _shutdown_executors() -> None:
    if _EXECUTORS is None:
        return

    # Copy the set. Stopping threads has the side effect of removing executors.
    executors = list(_EXECUTORS)

    # First signal all executors to close...
    for ref in executors:
        executor = ref()
        if executor:
            executor.close()

    # ...then try to join them.
    for ref in executors:
        executor = ref()
        if executor:
            executor.join(1)

    executor = None
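
# --- Illustrative sketch (not part of the diff): the self-cleaning weakref registry pattern
# --- used by _EXECUTORS above, shown standalone. All names below are invented for
# --- illustration; `weakref` is already imported by this module.
class _TrackedSketch:
    """Any object we want to track without keeping it alive."""


_registry_sketch: set[weakref.ref] = set()


def _track_sketch(obj: _TrackedSketch) -> None:
    # The callback fires when obj is garbage collected and discards the dead reference.
    ref = weakref.ref(obj, _registry_sketch.discard)
    _registry_sketch.add(ref)


# Example: t = _TrackedSketch(); _track_sketch(t); del t  # registry empties once t is collected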
venv/lib/python3.11/site-packages/pymongo/asynchronous/pool.py (new file, 1691 lines): diff suppressed because it is too large.

venv/lib/python3.11/site-packages/pymongo/asynchronous/server.py (new file, 383 lines):
@@ -0,0 +1,383 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Communicate with one MongoDB server in a topology."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.errors import NotPrimaryError, OperationFailure
|
||||
from pymongo.helpers_shared import _check_command_response
|
||||
from pymongo.logger import (
|
||||
_COMMAND_LOGGER,
|
||||
_SDAM_LOGGER,
|
||||
_CommandStatusMessage,
|
||||
_debug_log,
|
||||
_SDAMStatusMessage,
|
||||
)
|
||||
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
|
||||
from pymongo.response import PinnedResponse, Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from queue import Queue
|
||||
from weakref import ReferenceType
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
|
||||
from pymongo.asynchronous.monitor import Monitor
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(
|
||||
self,
|
||||
server_description: ServerDescription,
|
||||
pool: Pool,
|
||||
monitor: Monitor,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
listeners: Optional[_EventListeners] = None,
|
||||
events: Optional[ReferenceType[Queue]] = None,
|
||||
) -> None:
|
||||
"""Represent one MongoDB server."""
|
||||
self._description = server_description
|
||||
self._pool = pool
|
||||
self._monitor = monitor
|
||||
self._topology_id = topology_id
|
||||
self._publish = listeners is not None and listeners.enabled_for_server
|
||||
self._listener = listeners
|
||||
self._events = None
|
||||
if self._publish:
|
||||
self._events = events() # type: ignore[misc]
|
||||
|
||||
async def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
|
||||
if not self._pool.opts.load_balanced:
|
||||
self._monitor.open()
|
||||
|
||||
async def reset(self, service_id: Optional[ObjectId] = None) -> None:
|
||||
"""Clear the connection pool."""
|
||||
await self.pool.reset(service_id)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clear the connection pool and stop the monitor.
|
||||
|
||||
Reconnect with open().
|
||||
"""
|
||||
if self._publish:
|
||||
assert self._listener is not None
|
||||
assert self._events is not None
|
||||
self._events.put(
|
||||
(
|
||||
self._listener.publish_server_closed,
|
||||
(self._description.address, self._topology_id),
|
||||
)
|
||||
)
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology_id,
|
||||
serverHost=self._description.address[0],
|
||||
serverPort=self._description.address[1],
|
||||
message=_SDAMStatusMessage.STOP_SERVER,
|
||||
)
|
||||
|
||||
await self._monitor.close()
|
||||
await self._pool.close()
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""Check the server's state soon."""
|
||||
self._monitor.request_check()
|
||||
|
||||
async def operation_to_command(
|
||||
self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
cmd, db = operation.as_command(conn, apply_timeout)
|
||||
# Support auto encryption
|
||||
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
|
||||
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
|
||||
operation.db, cmd, operation.codec_options
|
||||
)
|
||||
operation.update_command(cmd)
|
||||
|
||||
return cmd, db
|
||||
|
||||
@_handle_reauth
|
||||
async def run_operation(
|
||||
self,
|
||||
conn: AsyncConnection,
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: _ServerMode,
|
||||
listeners: Optional[_EventListeners],
|
||||
unpack_res: Callable[..., list[_DocumentOut]],
|
||||
client: AsyncMongoClient,
|
||||
) -> Response:
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
This method is used only to run _Query/_GetMore operations from
|
||||
cursors.
|
||||
Can raise ConnectionFailure, OperationFailure, etc.
|
||||
|
||||
:param conn: An AsyncConnection instance.
|
||||
:param operation: A _Query or _GetMore object.
|
||||
:param read_preference: The read preference to use.
|
||||
:param listeners: Instance of _EventListeners or None.
|
||||
:param unpack_res: A callable that decodes the wire protocol response.
|
||||
:param client: An AsyncMongoClient instance.
|
||||
"""
|
||||
assert listeners is not None
|
||||
publish = listeners.enabled_for_commands
|
||||
start = datetime.now()
|
||||
|
||||
use_cmd = operation.use_command(conn)
|
||||
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
|
||||
cmd, dbn = await self.operation_to_command(operation, conn, use_cmd)
|
||||
if more_to_come:
|
||||
request_id = 0
|
||||
else:
|
||||
message = operation.get_message(read_preference, conn, use_cmd)
|
||||
request_id, data, max_doc_size = self._split_message(message)
|
||||
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.STARTED,
|
||||
command=cmd,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=dbn,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
)
|
||||
|
||||
if publish:
|
||||
if "$db" not in cmd:
|
||||
cmd["$db"] = dbn
|
||||
assert listeners is not None
|
||||
listeners.publish_command_start(
|
||||
cmd,
|
||||
dbn,
|
||||
request_id,
|
||||
conn.address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
)
|
||||
|
||||
try:
|
||||
if more_to_come:
|
||||
reply = await conn.receive_message(None)
|
||||
else:
|
||||
await conn.send_message(data, max_doc_size)
|
||||
reply = await conn.receive_message(request_id)
|
||||
|
||||
# Unpack and check for command errors.
|
||||
if use_cmd:
|
||||
user_fields = _CURSOR_DOC_FIELDS
|
||||
legacy_response = False
|
||||
else:
|
||||
user_fields = None
|
||||
legacy_response = True
|
||||
docs = unpack_res(
|
||||
reply,
|
||||
operation.cursor_id,
|
||||
operation.codec_options,
|
||||
legacy_response=legacy_response,
|
||||
user_fields=user_fields,
|
||||
)
|
||||
if use_cmd:
|
||||
first = docs[0]
|
||||
await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type]
|
||||
_check_command_response(first, conn.max_wire_version)
|
||||
except Exception as exc:
|
||||
duration = datetime.now() - start
|
||||
if isinstance(exc, (NotPrimaryError, OperationFailure)):
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.FAILED,
|
||||
durationMS=duration,
|
||||
failure=failure,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=dbn,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
isServerSideError=isinstance(exc, OperationFailure),
|
||||
)
|
||||
if publish:
|
||||
assert listeners is not None
|
||||
listeners.publish_command_failure(
|
||||
duration,
|
||||
failure,
|
||||
operation.name,
|
||||
request_id,
|
||||
conn.address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
database_name=dbn,
|
||||
)
|
||||
raise
|
||||
duration = datetime.now() - start
|
||||
# Must publish in find / getMore / explain command response
|
||||
# format.
|
||||
if use_cmd:
|
||||
res = docs[0]
|
||||
elif operation.name == "explain":
|
||||
res = docs[0] if docs else {}
|
||||
else:
|
||||
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
|
||||
if operation.name == "find":
|
||||
res["cursor"]["firstBatch"] = docs
|
||||
else:
|
||||
res["cursor"]["nextBatch"] = docs
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_debug_log(
|
||||
_COMMAND_LOGGER,
|
||||
clientId=client._topology_settings._topology_id,
|
||||
message=_CommandStatusMessage.SUCCEEDED,
|
||||
durationMS=duration,
|
||||
reply=res,
|
||||
commandName=next(iter(cmd)),
|
||||
databaseName=dbn,
|
||||
requestId=request_id,
|
||||
operationId=request_id,
|
||||
driverConnectionId=conn.id,
|
||||
serverConnectionId=conn.server_connection_id,
|
||||
serverHost=conn.address[0],
|
||||
serverPort=conn.address[1],
|
||||
serviceId=conn.service_id,
|
||||
)
|
||||
if publish:
|
||||
assert listeners is not None
|
||||
listeners.publish_command_success(
|
||||
duration,
|
||||
res,
|
||||
operation.name,
|
||||
request_id,
|
||||
conn.address,
|
||||
conn.server_connection_id,
|
||||
service_id=conn.service_id,
|
||||
database_name=dbn,
|
||||
)
|
||||
|
||||
# Decrypt response.
|
||||
client = operation.client # type: ignore[assignment]
|
||||
if client and client._encrypter:
|
||||
if use_cmd:
|
||||
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
|
||||
docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)
|
||||
|
||||
response: Response
|
||||
|
||||
if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
|
||||
conn.pin_cursor()
|
||||
if isinstance(reply, _OpMsg):
|
||||
# In OP_MSG, the server keeps sending only if the
|
||||
# more_to_come flag is set.
|
||||
more_to_come = reply.more_to_come
|
||||
else:
|
||||
# In OP_REPLY, the server keeps sending until cursor_id is 0.
|
||||
more_to_come = bool(operation.exhaust and reply.cursor_id)
|
||||
if operation.conn_mgr:
|
||||
operation.conn_mgr.update_exhaust(more_to_come)
|
||||
response = PinnedResponse(
|
||||
data=reply,
|
||||
address=self._description.address,
|
||||
conn=conn,
|
||||
duration=duration,
|
||||
request_id=request_id,
|
||||
from_command=use_cmd,
|
||||
docs=docs,
|
||||
more_to_come=more_to_come,
|
||||
)
|
||||
else:
|
||||
response = Response(
|
||||
data=reply,
|
||||
address=self._description.address,
|
||||
duration=duration,
|
||||
request_id=request_id,
|
||||
from_command=use_cmd,
|
||||
docs=docs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def checkout(
|
||||
self, handler: Optional[_MongoClientErrorHandler] = None
|
||||
) -> AsyncContextManager[AsyncConnection]:
|
||||
return self.pool.checkout(handler)
|
||||
|
||||
@property
|
||||
def description(self) -> ServerDescription:
|
||||
return self._description
|
||||
|
||||
@description.setter
|
||||
def description(self, server_description: ServerDescription) -> None:
|
||||
assert server_description.address == self._description.address
|
||||
self._description = server_description
|
||||
|
||||
@property
|
||||
def pool(self) -> Pool:
|
||||
return self._pool
|
||||
|
||||
    def _split_message(
        self, message: Union[tuple[int, Any], tuple[int, Any, int]]
    ) -> tuple[int, Any, int]:
        """Return request_id, data, max_doc_size.

        :param message: (request_id, data, max_doc_size) or (request_id, data)
        """
        if len(message) == 3:
            return message  # type: ignore[return-value]
        else:
            # get_more and kill_cursors messages don't include BSON documents.
            request_id, data = message  # type: ignore[misc]
            return request_id, data, 0

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._description!r}>"
@@ -0,0 +1,172 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Represent MongoClient's configuration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Collection, Optional, Type, Union
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import common
|
||||
from pymongo.asynchronous import monitor, pool
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TopologySettings:
|
||||
def __init__(
|
||||
self,
|
||||
seeds: Optional[Collection[tuple[str, int]]] = None,
|
||||
replica_set_name: Optional[str] = None,
|
||||
pool_class: Optional[Type[Pool]] = None,
|
||||
pool_options: Optional[PoolOptions] = None,
|
||||
monitor_class: Optional[Type[monitor.Monitor]] = None,
|
||||
condition_class: Optional[Type[threading.Condition]] = None,
|
||||
local_threshold_ms: int = LOCAL_THRESHOLD_MS,
|
||||
server_selection_timeout: int = SERVER_SELECTION_TIMEOUT,
|
||||
heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY,
|
||||
server_selector: Optional[_ServerSelector] = None,
|
||||
fqdn: Optional[str] = None,
|
||||
direct_connection: Optional[bool] = False,
|
||||
load_balanced: Optional[bool] = None,
|
||||
srv_service_name: str = common.SRV_SERVICE_NAME,
|
||||
srv_max_hosts: int = 0,
|
||||
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
|
||||
):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
Take a list of (host, port) pairs and optional replica set name.
|
||||
"""
|
||||
if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
|
||||
raise ConfigurationError(
|
||||
"heartbeatFrequencyMS cannot be less than %d"
|
||||
% (common.MIN_HEARTBEAT_INTERVAL * 1000,)
|
||||
)
|
||||
|
||||
self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)]
|
||||
self._replica_set_name = replica_set_name
|
||||
self._pool_class: Type[Pool] = pool_class or pool.Pool
|
||||
self._pool_options: PoolOptions = pool_options or PoolOptions()
|
||||
self._monitor_class: Type[monitor.Monitor] = monitor_class or monitor.Monitor
|
||||
self._condition_class: Type[threading.Condition] = condition_class or threading.Condition
|
||||
self._local_threshold_ms = local_threshold_ms
|
||||
self._server_selection_timeout = server_selection_timeout
|
||||
self._server_selector = server_selector
|
||||
self._fqdn = fqdn
|
||||
self._heartbeat_frequency = heartbeat_frequency
|
||||
self._direct = direct_connection
|
||||
self._load_balanced = load_balanced
|
||||
self._srv_service_name = srv_service_name
|
||||
self._srv_max_hosts = srv_max_hosts or 0
|
||||
self._server_monitoring_mode = server_monitoring_mode
|
||||
|
||||
self._topology_id = ObjectId()
|
||||
# Store the allocation traceback to catch unclosed clients in the
|
||||
# test suite.
|
||||
self._stack = "".join(traceback.format_stack()[:-2])
|
||||
|
||||
@property
|
||||
def seeds(self) -> Collection[tuple[str, int]]:
|
||||
"""List of server addresses."""
|
||||
return self._seeds
|
||||
|
||||
@property
|
||||
def replica_set_name(self) -> Optional[str]:
|
||||
return self._replica_set_name
|
||||
|
||||
@property
|
||||
def pool_class(self) -> Type[Pool]:
|
||||
return self._pool_class
|
||||
|
||||
@property
|
||||
def pool_options(self) -> PoolOptions:
|
||||
return self._pool_options
|
||||
|
||||
@property
|
||||
def monitor_class(self) -> Type[monitor.Monitor]:
|
||||
return self._monitor_class
|
||||
|
||||
@property
|
||||
def condition_class(self) -> Type[threading.Condition]:
|
||||
return self._condition_class
|
||||
|
||||
@property
|
||||
def local_threshold_ms(self) -> int:
|
||||
return self._local_threshold_ms
|
||||
|
||||
@property
|
||||
def server_selection_timeout(self) -> int:
|
||||
return self._server_selection_timeout
|
||||
|
||||
@property
|
||||
def server_selector(self) -> Optional[_ServerSelector]:
|
||||
return self._server_selector
|
||||
|
||||
@property
|
||||
def heartbeat_frequency(self) -> int:
|
||||
return self._heartbeat_frequency
|
||||
|
||||
@property
|
||||
def fqdn(self) -> Optional[str]:
|
||||
return self._fqdn
|
||||
|
||||
@property
|
||||
def direct(self) -> Optional[bool]:
|
||||
"""Connect directly to a single server, or use a set of servers?
|
||||
|
||||
True if there is one seed and no replica_set_name.
|
||||
"""
|
||||
return self._direct
|
||||
|
||||
@property
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if the client was configured to connect to a load balancer."""
|
||||
return self._load_balanced
|
||||
|
||||
@property
|
||||
def srv_service_name(self) -> str:
|
||||
"""The srvServiceName."""
|
||||
return self._srv_service_name
|
||||
|
||||
@property
|
||||
def srv_max_hosts(self) -> int:
|
||||
"""The srvMaxHosts."""
|
||||
return self._srv_max_hosts
|
||||
|
||||
@property
|
||||
def server_monitoring_mode(self) -> str:
|
||||
"""The serverMonitoringMode."""
|
||||
return self._server_monitoring_mode
|
||||
|
||||
    def get_topology_type(self) -> int:
        if self.load_balanced:
            return TOPOLOGY_TYPE.LoadBalanced
        elif self.direct:
            return TOPOLOGY_TYPE.Single
        elif self.replica_set_name is not None:
            return TOPOLOGY_TYPE.ReplicaSetNoPrimary
        else:
            return TOPOLOGY_TYPE.Unknown

    def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
        """Initial dict of (address, ServerDescription) for all seeds."""
        return {address: ServerDescription(address) for address in self.seeds}
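
# --- Illustrative sketch (not part of the diff): how the initial topology type above follows
# --- from the constructor arguments. The seed addresses and set name are invented for
# --- illustration; kept as a comment to avoid executing at import time.
#
#   settings = TopologySettings(
#       seeds=[("shard-a.example.net", 27017), ("shard-b.example.net", 27017)],
#       replica_set_name="rs0",
#   )
#   settings.get_topology_type() == TOPOLOGY_TYPE.ReplicaSetNoPrimary  # named set, no primary yet
#   settings.get_server_descriptions()  # one placeholder ServerDescription per seed address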
venv/lib/python3.11/site-packages/pymongo/asynchronous/topology.py (new file, 1099 lines): diff suppressed because it is too large.