fix
All checks were successful
Deploy Prod / Build (pull_request) Successful in 9s
Deploy Prod / Push (pull_request) Successful in 12s
Deploy Prod / Deploy prod (pull_request) Successful in 10s

This commit is contained in:
Egor Matveev
2024-12-28 22:48:16 +03:00
parent c1249bfcd0
commit 6c6a549aff
2532 changed files with 562109 additions and 1 deletion


@@ -0,0 +1,257 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Perform aggregation operations on a collection or database."""
from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _DocumentType, _Pipeline
_IS_SYNC = False
class _AggregationCommand:
"""The internal abstract base class for aggregation cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.asynchronous.collection.AsyncCollection.aggregate`, or
:meth:`pymongo.asynchronous.database.AsyncDatabase.aggregate` instead.
"""
def __init__(
self,
target: Union[AsyncDatabase, AsyncCollection],
cursor_class: type[AsyncCommandCursor],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
user_fields: Optional[MutableMapping[str, Any]] = None,
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
comment: Any = None,
) -> None:
if "explain" in options:
raise ConfigurationError(
"The explain option is not supported. Use AsyncDatabase.command instead."
)
self._target = target
pipeline = common.validate_list("pipeline", pipeline)
self._pipeline = pipeline
self._performs_write = False
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
self._performs_write = True
common.validate_is_mapping("options", options)
if let is not None:
common.validate_is_mapping("let", let)
options["let"] = let
if comment is not None:
options["comment"] = comment
self._options = options
# This is the batchSize that will be used for setting the initial
# batchSize for the cursor, as well as the subsequent getMores.
self._batch_size = common.validate_non_negative_integer_or_none(
"batchSize", self._options.pop("batchSize", None)
)
# If the cursor option is already specified, avoid overriding it.
self._options.setdefault("cursor", {})
# If the pipeline performs a write, we ignore the initial batchSize
# since the server doesn't return results in this case.
if self._batch_size is not None and not self._performs_write:
self._options["cursor"]["batchSize"] = self._batch_size
self._cursor_class = cursor_class
self._explicit_session = explicit_session
self._user_fields = user_fields
self._result_processor = result_processor
self._collation = validate_collation_or_none(options.pop("collation", None))
self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
self._write_preference: Optional[_AggWritePref] = None
@property
def _aggregation_target(self) -> Union[str, int]:
"""The argument to pass to the aggregate command."""
raise NotImplementedError
@property
def _cursor_namespace(self) -> str:
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
raise NotImplementedError
@property
def _database(self) -> AsyncDatabase:
"""The database against which the aggregation command is run."""
raise NotImplementedError
def get_read_preference(
self, session: Optional[AsyncClientSession]
) -> Union[_AggWritePref, _ServerMode]:
if self._write_preference:
return self._write_preference
pref = self._target._read_preference_for(session)
if self._performs_write and pref != ReadPreference.PRIMARY:
self._write_preference = pref = _AggWritePref(pref) # type: ignore[assignment]
return pref
async def get_cursor(
self,
session: Optional[AsyncClientSession],
server: Server,
conn: AsyncConnection,
read_preference: _ServerMode,
) -> AsyncCommandCursor[_DocumentType]:
# Serialize command.
cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
cmd.update(self._options)
# Apply this target's read concern if:
# readConcern has not been specified as a kwarg and either
# - server version is >= 4.2 or
# - server version is >= 3.2 and pipeline doesn't use $out
if ("readConcern" not in cmd) and (
not self._performs_write or (conn.max_wire_version >= 8)
):
read_concern = self._target.read_concern
else:
read_concern = None
# Apply this target's write concern if:
# writeConcern has not been specified as a kwarg and the pipeline
# performs a write operation ($out or $merge)
if "writeConcern" not in cmd and self._performs_write:
write_concern = self._target._write_concern_for(session)
else:
write_concern = None
# Run command.
result = await conn.command(
self._database.name,
cmd,
read_preference,
self._target.codec_options,
parse_write_concern_error=True,
read_concern=read_concern,
write_concern=write_concern,
collation=self._collation,
session=session,
client=self._database.client,
user_fields=self._user_fields,
)
if self._result_processor:
self._result_processor(result, conn)
# Extract cursor from result or mock/fake one if necessary.
if "cursor" in result:
cursor = result["cursor"]
else:
# Unacknowledged $out/$merge write. Fake a cursor.
cursor = {
"id": 0,
"firstBatch": result.get("result", []),
"ns": self._cursor_namespace,
}
# Create and return cursor instance.
cmd_cursor = self._cursor_class(
self._cursor_collection(cursor),
cursor,
conn.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session,
explicit_session=self._explicit_session,
comment=self._options.get("comment"),
)
await cmd_cursor._maybe_pin_connection(conn)
return cmd_cursor
class _CollectionAggregationCommand(_AggregationCommand):
_target: AsyncCollection
@property
def _aggregation_target(self) -> str:
return self._target.name
@property
def _cursor_namespace(self) -> str:
return self._target.full_name
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
return self._target
@property
def _database(self) -> AsyncDatabase:
return self._target.database
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# For raw-batches, we set the initial batchSize for the cursor to 0.
if not self._performs_write:
self._options["cursor"]["batchSize"] = 0
class _DatabaseAggregationCommand(_AggregationCommand):
_target: AsyncDatabase
@property
def _aggregation_target(self) -> int:
return 1
@property
def _cursor_namespace(self) -> str:
return f"{self._target.name}.$cmd.aggregate"
@property
def _database(self) -> AsyncDatabase:
return self._target
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
"""The AsyncCollection used for the aggregate command cursor."""
# AsyncCollection level aggregate may not always return the "ns" field
# according to our MockupDB tests. Let's handle that case for db level
# aggregate too by defaulting to the <db>.$cmd.aggregate namespace.
_, collname = cursor.get("ns", self._cursor_namespace).split(".", 1)
return self._database[collname]
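
These aggregation command classes are internal; applications reach them through the public aggregate helpers on AsyncCollection and AsyncDatabase. A minimal usage sketch, assuming PyMongo's async client is exported as pymongo.AsyncMongoClient (as in recent releases); the URI, database, collection, and pipeline contents are illustrative:

import asyncio

from pymongo import AsyncMongoClient

async def main() -> None:
    # Illustrative connection string; adjust for your deployment.
    client = AsyncMongoClient("mongodb://localhost:27017")
    coll = client.test_db.orders
    # AsyncCollection.aggregate constructs a _CollectionAggregationCommand
    # internally and returns an AsyncCommandCursor.
    pipeline = [
        {"$match": {"status": "shipped"}},
        {"$group": {"_id": "$customer_id", "total": {"$sum": "$amount"}}},
    ]
    async for doc in await coll.aggregate(pipeline, batchSize=100):
        print(doc)
    await client.close()

asyncio.run(main())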


@@ -0,0 +1,457 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Authentication helpers."""
from __future__ import annotations
import functools
import hashlib
import hmac
import socket
from base64 import standard_b64decode, standard_b64encode
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Mapping,
MutableMapping,
Optional,
cast,
)
from urllib.parse import quote
from bson.binary import Binary
from pymongo.asynchronous.auth_aws import _authenticate_aws
from pymongo.asynchronous.auth_oidc import (
_authenticate_oidc,
_get_authenticator,
)
from pymongo.auth_shared import (
MongoCredential,
_authenticate_scram_start,
_parse_scram_response,
_xor,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.saslprep import saslprep
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.hello import Hello
HAVE_KERBEROS = True
_USE_PRINCIPAL = False
try:
import winkerberos as kerberos # type:ignore[import]
if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5):
_USE_PRINCIPAL = True
except ImportError:
try:
import kerberos # type:ignore[import]
except ImportError:
HAVE_KERBEROS = False
_IS_SYNC = False
async def _authenticate_scram(
credentials: MongoCredential, conn: AsyncConnection, mechanism: str
) -> None:
"""Authenticate using SCRAM."""
username = credentials.username
if mechanism == "SCRAM-SHA-256":
digest = "sha256"
digestmod = hashlib.sha256
data = saslprep(credentials.password).encode("utf-8")
else:
digest = "sha1"
digestmod = hashlib.sha1
data = _password_digest(username, credentials.password).encode("utf-8")
source = credentials.source
cache = credentials.cache
# Make local
_hmac = hmac.HMAC
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
assert isinstance(ctx, _ScramContext)
assert ctx.scram_data is not None
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = await conn.command(source, cmd)
assert res is not None
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b"i"])
if iterations < 4096:
raise OperationFailure("Server returned an invalid iteration count.")
salt = parsed[b"s"]
rnonce = parsed[b"r"]
if not rnonce.startswith(nonce):
raise OperationFailure("Server returned an invalid nonce.")
without_proof = b"c=biws,r=" + rnonce
if cache.data:
client_key, server_key, csalt, citerations = cache.data
else:
client_key, server_key, csalt, citerations = None, None, None, None
# Salt and / or iterations could change for a number of different
# reasons. Either changing invalidates the cache.
if not client_key or salt != csalt or iterations != citerations:
salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations)
client_key = _hmac(salted_pass, b"Client Key", digestmod).digest()
server_key = _hmac(salted_pass, b"Server Key", digestmod).digest()
cache.data = (client_key, server_key, salt, iterations)
stored_key = digestmod(client_key).digest()
auth_msg = b",".join((first_bare, server_first, without_proof))
client_sig = _hmac(stored_key, auth_msg, digestmod).digest()
client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig))
client_final = b",".join((without_proof, client_proof))
server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest())
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(client_final),
}
res = await conn.command(source, cmd)
parsed = _parse_scram_response(res["payload"])
if not hmac.compare_digest(parsed[b"v"], server_sig):
raise OperationFailure("Server returned an invalid signature.")
# A third empty challenge may be required if the server does not support
# skipEmptyExchange: SERVER-44857.
if not res["done"]:
cmd = {
"saslContinue": 1,
"conversationId": res["conversationId"],
"payload": Binary(b""),
}
res = await conn.command(source, cmd)
if not res["done"]:
raise OperationFailure("SASL conversation failed to complete.")
def _password_digest(username: str, password: str) -> str:
"""Get a password digest to use for authentication."""
if not isinstance(password, str):
raise TypeError("password must be an instance of str")
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _auth_key(nonce: str, username: str, password: str) -> str:
"""Get an auth key to use for authentication."""
digest = _password_digest(username, password)
md5hash = hashlib.md5() # noqa: S324
data = f"{nonce}{username}{digest}"
md5hash.update(data.encode("utf-8"))
return md5hash.hexdigest()
def _canonicalize_hostname(hostname: str) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
try:
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
except socket.gaierror:
return canonname.lower()
return name[0].lower()
async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using GSSAPI."""
if not HAVE_KERBEROS:
raise ConfigurationError(
'The "kerberos" module must be installed to use GSSAPI authentication.'
)
try:
username = credentials.username
password = credentials.password
props = credentials.mechanism_properties
# Starting here and continuing through the loop below, establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = conn.address[0]
if props.canonicalize_host_name:
host = _canonicalize_hostname(host)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
if password is not None:
if _USE_PRINCIPAL:
# Note that, though we use unquote_plus for unquoting URI
# options, we use quote here. Microsoft's UrlUnescape (used
# by WinKerberos) doesn't support +.
principal = ":".join((quote(username), quote(password)))
result, ctx = kerberos.authGSSClientInit(
service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG
)
else:
if "@" in username:
user, domain = username.split("@", 1)
else:
user, domain = username, None
result, ctx = kerberos.authGSSClientInit(
service,
gssflags=kerberos.GSS_C_MUTUAL_FLAG,
user=user,
domain=domain,
password=password,
)
else:
result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
if result != kerberos.AUTH_GSS_COMPLETE:
raise OperationFailure("Kerberos context failed to initialize.")
try:
# pykerberos uses a weird mix of exceptions and return values
# to indicate errors.
# 0 == continue, 1 == complete, -1 == error
# Only authGSSClientStep can return 0.
if kerberos.authGSSClientStep(ctx, "") != 0:
raise OperationFailure("Unknown kerberos failure in step function.")
# Start a SASL conversation with mongod/s
# Note: pykerberos deals with base64 encoded byte strings.
# Since mongo accepts base64 strings as the payload we don't
# have to use bson.binary.Binary.
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslStart": 1,
"mechanism": "GSSAPI",
"payload": payload,
"autoAuthorize": 1,
}
response = await conn.command("$external", cmd)
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
result = kerberos.authGSSClientStep(ctx, str(response["payload"]))
if result == -1:
raise OperationFailure("Unknown kerberos failure in step function.")
payload = kerberos.authGSSClientResponse(ctx) or ""
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
response = await conn.command("$external", cmd)
if result == kerberos.AUTH_GSS_COMPLETE:
break
else:
raise OperationFailure("Kerberos authentication failed to complete.")
# Once the security context is established actually authenticate.
# See RFC 4752, Section 3.1, last two paragraphs.
if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.")
if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1:
raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.")
payload = kerberos.authGSSClientResponse(ctx)
cmd = {
"saslContinue": 1,
"conversationId": response["conversationId"],
"payload": payload,
}
await conn.command("$external", cmd)
finally:
kerberos.authGSSClientClean(ctx)
except kerberos.KrbError as exc:
raise OperationFailure(str(exc)) from None
async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using SASL PLAIN (RFC 4616)"""
source = credentials.source
username = credentials.username
password = credentials.password
payload = (f"\x00{username}\x00{password}").encode()
cmd = {
"saslStart": 1,
"mechanism": "PLAIN",
"payload": Binary(payload),
"autoAuthorize": 1,
}
await conn.command(source, cmd)
async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-X509."""
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
# MONGODB-X509 is done after the speculative auth step.
return
cmd = _X509Context(credentials, conn.address).speculate_command()
await conn.command("$external", cmd)
async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-CR."""
source = credentials.source
username = credentials.username
password = credentials.password
# Get a nonce
response = await conn.command(source, {"getnonce": 1})
nonce = response["nonce"]
key = _auth_key(nonce, username, password)
# Actually authenticate
query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key}
await conn.command(source, query)
async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None:
if conn.max_wire_version >= 7:
if conn.negotiated_mechs:
mechs = conn.negotiated_mechs
else:
source = credentials.source
cmd = conn.hello_cmd()
cmd["saslSupportedMechs"] = source + "." + credentials.username
mechs = (await conn.command(source, cmd, publish_events=False)).get(
"saslSupportedMechs", []
)
if "SCRAM-SHA-256" in mechs:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-256")
else:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
else:
return await _authenticate_scram(credentials, conn, "SCRAM-SHA-1")
_AUTH_MAP: Mapping[str, Callable[..., Coroutine[Any, Any, None]]] = {
"GSSAPI": _authenticate_gssapi,
"MONGODB-CR": _authenticate_mongo_cr,
"MONGODB-X509": _authenticate_x509,
"MONGODB-AWS": _authenticate_aws,
"MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item]
"PLAIN": _authenticate_plain,
"SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"),
"DEFAULT": _authenticate_default,
}
class _AuthContext:
def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None:
self.credentials = credentials
self.speculative_authenticate: Optional[Mapping[str, Any]] = None
self.address = address
@staticmethod
def from_credentials(
creds: MongoCredential, address: tuple[str, int]
) -> Optional[_AuthContext]:
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
if spec_cls:
return cast(_AuthContext, spec_cls(creds, address))
return None
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
raise NotImplementedError
def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None:
self.speculative_authenticate = hello.speculative_authenticate
def speculate_succeeded(self) -> bool:
return bool(self.speculative_authenticate)
class _ScramContext(_AuthContext):
def __init__(
self, credentials: MongoCredential, address: tuple[str, int], mechanism: str
) -> None:
super().__init__(credentials, address)
self.scram_data: Optional[tuple[bytes, bytes]] = None
self.mechanism = mechanism
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism)
# The 'db' field is included only on the speculative command.
cmd["db"] = self.credentials.source
# Save for later use.
self.scram_data = (nonce, first_bare)
return cmd
class _X509Context(_AuthContext):
def speculate_command(self) -> MutableMapping[str, Any]:
cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"}
if self.credentials.username is not None:
cmd["user"] = self.credentials.username
return cmd
class _OIDCContext(_AuthContext):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
authenticator = _get_authenticator(self.credentials, self.address)
cmd = authenticator.get_spec_auth_cmd()
if cmd is None:
return None
cmd["db"] = self.credentials.source
return cmd
_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = {
"MONGODB-X509": _X509Context,
"SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"),
"SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
"MONGODB-OIDC": _OIDCContext,
"DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"),
}
async def authenticate(
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False
) -> None:
"""Authenticate connection."""
mechanism = credentials.mechanism
auth_func = _AUTH_MAP[mechanism]
if mechanism == "MONGODB-OIDC":
await _authenticate_oidc(credentials, conn, reauthenticate)
else:
await auth_func(credentials, conn)
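
For reference, the client proof that _authenticate_scram sends in the saslContinue step follows RFC 5802: SaltedPassword = PBKDF2(password, salt, i), ClientKey = HMAC(SaltedPassword, "Client Key"), StoredKey = H(ClientKey), ClientSignature = HMAC(StoredKey, AuthMessage), and ClientProof = ClientKey XOR ClientSignature. A standalone sketch of that derivation (the inputs below are made up, not real wire traffic):

import hashlib
import hmac
from base64 import standard_b64encode

def scram_client_proof(
    password: bytes, salt: bytes, iterations: int, auth_message: bytes, digest: str = "sha256"
) -> bytes:
    # Mirrors the derivation performed in _authenticate_scram (RFC 5802).
    digestmod = getattr(hashlib, digest)
    salted_pass = hashlib.pbkdf2_hmac(digest, password, salt, iterations)
    client_key = hmac.HMAC(salted_pass, b"Client Key", digestmod).digest()
    stored_key = digestmod(client_key).digest()
    client_sig = hmac.HMAC(stored_key, auth_message, digestmod).digest()
    # ClientProof = ClientKey XOR ClientSignature
    proof = bytes(a ^ b for a, b in zip(client_key, client_sig))
    return standard_b64encode(proof)

print(scram_client_proof(b"secret", b"0123456789abcdef", 4096, b"n=user,r=...,c=biws,r=..."))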


@@ -0,0 +1,100 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-AWS Authentication helpers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Type
import bson
from bson.binary import Binary
from pymongo.errors import ConfigurationError, OperationFailure
if TYPE_CHECKING:
from bson.typings import _ReadableBuffer
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.auth_shared import MongoCredential
_IS_SYNC = False
async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None:
"""Authenticate using MONGODB-AWS."""
try:
import pymongo_auth_aws # type:ignore[import]
except ImportError as e:
raise ConfigurationError(
"MONGODB-AWS authentication requires pymongo-auth-aws: "
"install with: python -m pip install 'pymongo[aws]'"
) from e
# Delayed import.
from pymongo_auth_aws.auth import ( # type:ignore[import]
set_cached_credentials,
set_use_cached_credentials,
)
set_use_cached_credentials(True)
if conn.max_wire_version < 9:
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
class AwsSaslContext(pymongo_auth_aws.AwsSaslContext): # type: ignore
# Dependency injection:
def binary_type(self) -> Type[Binary]:
"""Return the bson.binary.Binary type."""
return Binary
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
"""Encode a dictionary to BSON."""
return bson.encode(doc)
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
"""Decode BSON to a dictionary."""
return bson.decode(data)
try:
ctx = AwsSaslContext(
pymongo_auth_aws.AwsCredential(
credentials.username,
credentials.password,
credentials.mechanism_properties.aws_session_token,
)
)
client_payload = ctx.step(None)
client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload}
server_first = await conn.command("$external", client_first)
res = server_first
# Limit how many times we loop to catch protocol / library issues
for _ in range(10):
client_payload = ctx.step(res["payload"])
cmd = {
"saslContinue": 1,
"conversationId": server_first["conversationId"],
"payload": client_payload,
}
res = await conn.command("$external", cmd)
if res["done"]:
# SASL complete.
break
except pymongo_auth_aws.PyMongoAuthAwsError as exc:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure(
f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
) from None
except Exception:
# Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None)
raise
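
From the application's point of view, MONGODB-AWS is selected through client options rather than by calling this helper directly. A hedged sketch, assuming pymongo-auth-aws is installed (python -m pip install 'pymongo[aws]') and using an illustrative cluster host; credentials may also come from AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY/AWS_SESSION_TOKEN or from instance metadata:

import asyncio

from pymongo import AsyncMongoClient

async def main() -> None:
    client = AsyncMongoClient(
        "mongodb+srv://example-cluster.mongodb.net",  # illustrative host
        authMechanism="MONGODB-AWS",
    )
    print(await client.admin.command("ping"))
    await client.close()

asyncio.run(main())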


@@ -0,0 +1,294 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
import bson
from bson.binary import Binary
from pymongo._csot import remaining
from pymongo.auth_oidc_shared import (
CALLBACK_VERSION,
HUMAN_CALLBACK_TIMEOUT_SECONDS,
MACHINE_CALLBACK_TIMEOUT_SECONDS,
TIME_BETWEEN_CALLS_SECONDS,
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
OIDCIdPInfo,
_OIDCProperties,
)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.auth_shared import MongoCredential
_IS_SYNC = False
def _get_authenticator(
credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
if credentials.cache.data:
return credentials.cache.data
# Extract values.
principal_name = credentials.username
properties = credentials.mechanism_properties
# Validate that the address is allowed.
if not properties.environment:
found = False
allowed_hosts = properties.allowed_hosts
for patt in allowed_hosts:
if patt == address[0]:
found = True
elif patt.startswith("*.") and address[0].endswith(patt[1:]):
found = True
if not found:
raise ConfigurationError(
f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
)
# Get or create the cache data.
credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
return credentials.cache.data
@dataclass
class _OIDCAuthenticator:
username: str
properties: _OIDCProperties
refresh_token: Optional[str] = field(default=None)
access_token: Optional[str] = field(default=None)
idp_info: Optional[OIDCIdPInfo] = field(default=None)
token_gen_id: int = field(default=0)
lock: threading.Lock = field(default_factory=threading.Lock)
last_call_time: float = field(default=0)
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
"""Handle a reauthenticate from the server."""
# Invalidate the token for the connection.
self._invalidate(conn)
# Call the appropriate auth logic for the callback type.
if self.properties.callback:
return await self._authenticate_machine(conn)
return await self._authenticate_human(conn)
async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
"""Handle an initial authenticate request."""
# First handle speculative auth.
# If it succeeded, we are done.
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
resp = ctx.speculative_authenticate
if resp and resp["done"]:
conn.oidc_token_gen_id = self.token_gen_id
return resp
# If spec auth failed, call the appropriate auth logic for the callback type.
# We cannot assume that the token is invalid, because a proxy may have been
# involved that stripped the speculative auth information.
if self.properties.callback:
return await self._authenticate_machine(conn)
return await self._authenticate_human(conn)
def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
"""Get the appropriate speculative auth command."""
if not self.access_token:
return None
return self._get_start_command({"jwt": self.access_token})
async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return await self._authenticate_machine(conn)
raise
return await self._sasl_start_jwt(conn)
async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest. If
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
return await self._authenticate_human(conn)
raise
# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return await self._sasl_start_jwt(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return await self._authenticate_human(conn)
raise
# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
cmd = self._get_start_command(None)
start_resp = await self._run_command(conn, cmd)
# Attempt to authenticate with a JwtStepRequest.
return await self._sasl_continue_jwt(conn, start_resp)
def _get_access_token(self) -> Optional[str]:
properties = self.properties
cb: Union[None, OIDCCallback]
resp: OIDCCallbackResult
is_human = properties.human_callback is not None
if is_human and self.idp_info is None:
return None
if properties.callback:
cb = properties.callback
if properties.human_callback:
cb = properties.human_callback
prev_token = self.access_token
if prev_token:
return prev_token
if cb is None and not prev_token:
return None
if not prev_token and cb is not None:
with self.lock:
# See if the token was changed while we were waiting for the
# lock.
new_token = self.access_token
if new_token != prev_token:
return new_token
# Ensure that we are waiting a min time between callback invocations.
delta = time.time() - self.last_call_time
if delta < TIME_BETWEEN_CALLS_SECONDS:
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
self.last_call_time = time.time()
if is_human:
timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
assert self.idp_info is not None
else:
timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
context = OIDCCallbackContext(
timeout_seconds=timeout,
version=CALLBACK_VERSION,
refresh_token=self.refresh_token,
idp_info=self.idp_info,
username=self.properties.username,
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1
return self.access_token
async def _run_command(
self, conn: AsyncConnection, cmd: MutableMapping[str, Any]
) -> Mapping[str, Any]:
try:
return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise
def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE
def _invalidate(self, conn: AsyncConnection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
token_gen_id = conn.oidc_token_gen_id or 0
if token_gen_id is not None and token_gen_id < self.token_gen_id:
return
self.access_token = None
async def _sasl_continue_jwt(
self, conn: AsyncConnection, start_resp: Mapping[str, Any]
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
return await self._run_command(conn, cmd)
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
access_token = self._get_access_token()
conn.oidc_token_gen_id = self.token_gen_id
cmd = self._get_start_command({"jwt": access_token})
return await self._run_command(conn, cmd)
def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
if payload is None:
principal_name = self.username
if principal_name:
payload = {"n": principal_name}
else:
payload = {}
bin_payload = Binary(bson.encode(payload))
return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}
def _get_continue_command(
self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
) -> MutableMapping[str, Any]:
bin_payload = Binary(bson.encode(payload))
return {
"saslContinue": 1,
"payload": bin_payload,
"conversationId": start_resp["conversationId"],
}
async def _authenticate_oidc(
credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
"""Authenticate using MONGODB-OIDC."""
authenticator = _get_authenticator(credentials, conn.address)
if reauthenticate:
return await authenticator.reauthenticate(conn)
else:
return await authenticator.authenticate(conn)
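
Applications drive this machinery by supplying an OIDC callback in the client options. A sketch of a machine (workload) callback, assuming the callback types are importable from pymongo.auth_oidc_shared as in the module above; read_token_from_disk is a hypothetical stand-in for however your environment obtains a JWT:

import asyncio

from pymongo import AsyncMongoClient
from pymongo.auth_oidc_shared import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult

def read_token_from_disk() -> str:
    # Hypothetical helper: return a current access token for this workload.
    with open("/var/run/secrets/oidc/token") as f:
        return f.read().strip()

class MachineCallback(OIDCCallback):
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        return OIDCCallbackResult(access_token=read_token_from_disk())

async def main() -> None:
    client = AsyncMongoClient(
        "mongodb://localhost:27017",  # illustrative
        authMechanism="MONGODB-OIDC",
        authMechanismProperties={"OIDC_CALLBACK": MachineCallback()},
    )
    print(await client.admin.command("ping"))
    await client.close()

asyncio.run(main())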


@@ -0,0 +1,738 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The bulk write operations interface.
.. versionadded:: 2.7
"""
from __future__ import annotations
import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Iterator,
Mapping,
Optional,
Type,
Union,
)
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
_COMMANDS,
_DELETE_ALL,
_merge_command,
_raise_bulk_write_error,
_Run,
)
from pymongo.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.errors import (
ConfigurationError,
InvalidOperation,
NotPrimaryError,
OperationFailure,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_DELETE,
_INSERT,
_UPDATE,
_BulkWriteContext,
_convert_exception,
_convert_write_result,
_EncryptedBulkWriteContext,
_randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
_IS_SYNC = False
class _AsyncBulk:
"""The private guts of the bulk write API."""
def __init__(
self,
collection: AsyncCollection[_DocumentType],
ordered: bool,
bypass_document_validation: bool,
comment: Optional[str] = None,
let: Optional[Any] = None,
) -> None:
"""Initialize a _AsyncBulk instance."""
self.collection = collection.with_options(
codec_options=collection.codec_options._replace(
unicode_decode_error_handler="replace", document_class=dict
)
)
self.let = let
if self.let is not None:
common.validate_is_document_type("let", self.let)
self.comment: Optional[str] = comment
self.ordered = ordered
self.ops: list[tuple[int, Mapping[str, Any]]] = []
self.executed = False
self.bypass_doc_val = bypass_document_validation
self.uses_collation = False
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.is_retryable = True
self.retrying = False
self.started_retryable_write = False
# Extra state so that we know where to pick up on a retry attempt.
self.current_run = None
self.next_run = None
self.is_encrypted = False
@property
def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
encrypter = self.collection.database.client._encrypter
if encrypter and not encrypter._bypass_auto_encryption:
self.is_encrypted = True
return _EncryptedBulkWriteContext
else:
self.is_encrypted = False
return _BulkWriteContext
def add_insert(self, document: _DocumentOut) -> None:
"""Add an insert document to the list of ops."""
validate_is_document_type("document", document)
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
self.ops.append((_INSERT, document))
def add_update(
self,
selector: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
multi: bool = False,
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: dict[str, Any] = dict( # noqa: C406
[("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]
)
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if array_filters is not None:
self.uses_array_filters = True
cmd["arrayFilters"] = array_filters
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append((_UPDATE, cmd))
def add_replace(
self,
selector: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd = {"q": selector, "u": replacement, "multi": False, "upsert": upsert}
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
self.ops.append((_UPDATE, cmd))
def add_delete(
self,
selector: Mapping[str, Any],
limit: int,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a delete document and add it to the list of ops."""
cmd = {"q": selector, "limit": limit}
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint
if limit == _DELETE_ALL:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append((_DELETE, cmd))
def gen_ordered(self) -> Iterator[Optional[_Run]]:
"""Generate batches of operations, batched by type of
operation, in the order **provided**.
"""
run = None
for idx, (op_type, operation) in enumerate(self.ops):
if run is None:
run = _Run(op_type)
elif run.op_type != op_type:
yield run
run = _Run(op_type)
run.add(idx, operation)
yield run
def gen_unordered(self) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in arbitrary order.
"""
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
for idx, (op_type, operation) in enumerate(self.ops):
operations[op_type].add(idx, operation)
for run in operations:
if run.ops:
yield run
@_handle_reauth
async def write_command(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, docs)
try:
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
return reply # type: ignore[return-value]
async def unack_write(
self,
bwc: _BulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
max_doc_size: int,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, docs)
try:
result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
raise
return result # type: ignore[return-value]
async def _execute_batch_unack(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> list[Mapping[str, Any]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
await bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
write_concern=WriteConcern(w=0),
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
# Though this isn't strictly a "legacy" write, the helper
# handles publishing commands and sending our message
# without receiving a result. Send 0 for max_doc_size
# to disable size checking. Size checking is handled while
# the documents are encoded to BSON.
await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type]
return to_send
async def _execute_batch(
self,
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
result = await bwc.conn.command( # type: ignore[misc]
bwc.db_name,
batched_cmd, # type: ignore[arg-type]
codec_options=bwc.codec,
session=bwc.session, # type: ignore[arg-type]
client=client, # type: ignore[arg-type]
)
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
return result, to_send # type: ignore[return-value]
async def _execute_command(
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
final_write_concern: Optional[WriteConcern] = None,
) -> None:
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
if not self.current_run:
self.current_run = next(generator)
self.next_run = None
run = self.current_run
# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
conn.validate_session(client, session)
last_run = False
while run:
if not self.retrying:
self.next_run = next(generator, None)
if self.next_run is None:
last_run = True
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners,
session,
run.op_type,
self.collection.codec_options,
)
while run.idx_offset < len(run.ops):
# If this is the last possible operation, use the
# final write concern.
if last_run and (len(run.ops) - run.idx_offset) == 1:
write_concern = final_write_concern or write_concern
cmd = {cmd_name: self.collection.name, "ordered": self.ordered}
if self.comment:
cmd["comment"] = self.comment
_csot.apply_write_concern(cmd, write_concern)
if self.bypass_doc_val:
cmd["bypassDocumentValidation"] = True
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
cmd["let"] = self.let
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible in one command.
if write_concern.acknowledged:
result, to_send = await self._execute_batch(bwc, cmd, ops, client)
# Retryable writeConcernErrors halt the execution of this run.
wce = result.get("writeConcernError", {})
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
full = copy.deepcopy(full_result)
_merge_command(run, full, run.idx_offset, result)
_raise_bulk_write_error(full)
_merge_command(run, full_result, run.idx_offset, result)
# We're no longer in a retry once a command succeeds.
self.retrying = False
self.started_retryable_write = False
if self.ordered and "writeErrors" in result:
break
else:
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
# We're supposed to continue if errors are
# at the write concern level (e.g. wtimeout)
if self.ordered and full_result["writeErrors"]:
break
# Reset our state
self.current_run = run = self.next_run
async def execute_command(
self,
generator: Iterator[Any],
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
operation: str,
) -> dict[str, Any]:
"""Execute using write commands."""
# nModified is only reported for write commands, not legacy ops.
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
op_id = _randint()
async def retryable_bulk(
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
) -> None:
await self._execute_command(
generator,
write_concern,
session,
conn,
op_id,
retryable,
full_result,
)
client = self.collection.database.client
_ = await client._retryable_write(
self.is_retryable,
retryable_bulk,
session,
operation,
bulk=self, # type: ignore[arg-type]
operation_id=op_id,
)
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
_raise_bulk_write_error(full_result)
return full_result
async def execute_op_msg_no_results(
self, conn: AsyncConnection, generator: Iterator[Any]
) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
client = self.collection.database.client
listeners = client._event_listeners
op_id = _randint()
if not self.current_run:
self.current_run = next(generator)
run = self.current_run
while run:
cmd_name = _COMMANDS[run.op_type]
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners,
None,
run.op_type,
self.collection.codec_options,
)
while run.idx_offset < len(run.ops):
cmd = {
cmd_name: self.collection.name,
"ordered": False,
"writeConcern": {"w": 0},
}
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
async def execute_command_no_results(
self,
conn: AsyncConnection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
"""Execute write commands with OP_MSG and w=0 WriteConcern, ordered."""
full_result = {
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
# Ordered bulk writes have to be acknowledged so that we stop
# processing at the first error, even when the application
# specified unacknowledged writeConcern.
initial_write_concern = WriteConcern()
op_id = _randint()
try:
await self._execute_command(
generator,
initial_write_concern,
None,
conn,
op_id,
False,
full_result,
write_concern,
)
except OperationFailure:
pass
async def execute_no_results(
self,
conn: AsyncConnection,
generator: Iterator[Any],
write_concern: WriteConcern,
) -> None:
"""Execute all operations, returning no results (w=0)."""
if self.uses_collation:
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
if self.uses_array_filters:
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
# Guard against unsupported unacknowledged writes.
unack = write_concern and not write_concern.acknowledged
if unack and self.uses_hint_delete and conn.max_wire_version < 9:
raise ConfigurationError(
"Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands."
)
if unack and self.uses_hint_update and conn.max_wire_version < 8:
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val:
raise OperationFailure(
"Cannot set bypass_document_validation with unacknowledged write concern"
)
if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator)
async def execute(
self,
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
operation: str,
) -> Any:
"""Execute operations."""
if not self.ops:
raise InvalidOperation("No operations to execute")
if self.executed:
raise InvalidOperation("Bulk operations can only be executed once.")
self.executed = True
write_concern = write_concern or self.collection.write_concern
session = _validate_session_write_concern(session, write_concern)
if self.ordered:
generator = self.gen_ordered()
else:
generator = self.gen_unordered()
client = self.collection.database.client
if not write_concern.acknowledged:
async with await client._conn_for_writes(session, operation) as connection:
await self.execute_no_results(connection, generator, write_concern)
return None
else:
return await self.execute_command(generator, write_concern, session, operation)
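
_AsyncBulk itself is reached through the public bulk-write entry point, which adds each request via add_insert/add_update/add_delete and then calls execute(). A minimal sketch with illustrative data and connection details:

import asyncio

from pymongo import AsyncMongoClient, DeleteMany, InsertOne, UpdateOne

async def main() -> None:
    client = AsyncMongoClient("mongodb://localhost:27017")
    coll = client.test_db.items
    requests = [
        InsertOne({"_id": 1, "qty": 5}),
        UpdateOne({"_id": 1}, {"$inc": {"qty": 1}}, upsert=True),
        DeleteMany({"qty": {"$lt": 0}}),  # a delete_many makes the batch non-retryable
    ]
    result = await coll.bulk_write(requests, ordered=True)
    print(result.bulk_api_result)
    await client.close()

asyncio.run(main())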


@@ -0,0 +1,498 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Watch changes on a collection, a database, or the entire cluster."""
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union
from bson import CodecOptions, _bson_to_dict
from bson.raw_bson import RawBSONDocument
from bson.timestamp import Timestamp
from pymongo import _csot, common
from pymongo.asynchronous.aggregation import (
_AggregationCommand,
_CollectionAggregationCommand,
_DatabaseAggregationCommand,
)
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (
ConnectionFailure,
CursorNotFound,
InvalidOperation,
OperationFailure,
PyMongoError,
)
from pymongo.operations import _Op
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
_IS_SYNC = False
# The change streams spec considers the following server errors from the
# getMore command non-resumable. All other getMore errors are resumable.
_RESUMABLE_GETMORE_ERRORS = frozenset(
[
6, # HostUnreachable
7, # HostNotFound
89, # NetworkTimeout
91, # ShutdownInProgress
189, # PrimarySteppedDown
262, # ExceededTimeLimit
9001, # SocketException
10107, # NotWritablePrimary
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13435, # NotPrimaryNoSecondaryOk
13436, # NotPrimaryOrSecondary
63, # StaleShardVersion
150, # StaleEpoch
13388, # StaleConfig
234, # RetryChangeStream
133, # FailedToSatisfyReadPreference
]
)
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
def _resumable(exc: PyMongoError) -> bool:
"""Return True if given a resumable change stream error."""
if isinstance(exc, (ConnectionFailure, CursorNotFound)):
return True
if isinstance(exc, OperationFailure):
if exc._max_wire_version is None:
return False
return (
exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
return False
class AsyncChangeStream(Generic[_DocumentType]):
"""The internal abstract base class for change stream cursors.
Should not be called directly by application developers. Use
:meth:`pymongo.asynchronous.collection.AsyncCollection.watch`,
:meth:`pymongo.asynchronous.database.AsyncDatabase.watch`, or
:meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` instead.
.. versionadded:: 3.6
.. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
"""
def __init__(
self,
target: Union[
AsyncMongoClient[_DocumentType],
AsyncDatabase[_DocumentType],
AsyncCollection[_DocumentType],
],
pipeline: Optional[_Pipeline],
full_document: Optional[str],
resume_after: Optional[Mapping[str, Any]],
max_await_time_ms: Optional[int],
batch_size: Optional[int],
collation: Optional[_CollationIn],
start_at_operation_time: Optional[Timestamp],
session: Optional[AsyncClientSession],
start_after: Optional[Mapping[str, Any]],
comment: Optional[Any] = None,
full_document_before_change: Optional[str] = None,
show_expanded_events: Optional[bool] = None,
) -> None:
if pipeline is None:
pipeline = []
pipeline = common.validate_list("pipeline", pipeline)
common.validate_string_or_none("full_document", full_document)
validate_collation_or_none(collation)
common.validate_non_negative_integer_or_none("batchSize", batch_size)
self._decode_custom = False
self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
if target.codec_options.type_registry._decoder_map:
self._decode_custom = True
# Keep the type registry so that we support encoding custom types
# in the pipeline.
self._target = target.with_options( # type: ignore
codec_options=target.codec_options.with_options(document_class=RawBSONDocument)
)
else:
self._target = target
self._pipeline = copy.deepcopy(pipeline)
self._full_document = full_document
self._full_document_before_change = full_document_before_change
self._uses_start_after = start_after is not None
self._uses_resume_after = resume_after is not None
self._resume_token = copy.deepcopy(start_after or resume_after)
self._max_await_time_ms = max_await_time_ms
self._batch_size = batch_size
self._collation = collation
self._start_at_operation_time = start_at_operation_time
self._session = session
self._comment = comment
self._closed = False
self._timeout = self._target._timeout
self._show_expanded_events = show_expanded_events
async def _initialize_cursor(self) -> None:
# Initialize cursor.
self._cursor = await self._create_cursor()
@property
def _aggregation_command_class(self) -> Type[_AggregationCommand]:
"""The aggregation command class to be used."""
raise NotImplementedError
@property
def _client(self) -> AsyncMongoClient:
"""The client against which the aggregation commands for
this AsyncChangeStream will be run.
"""
raise NotImplementedError
def _change_stream_options(self) -> dict[str, Any]:
"""Return the options dict for the $changeStream pipeline stage."""
options: dict[str, Any] = {}
if self._full_document is not None:
options["fullDocument"] = self._full_document
if self._full_document_before_change is not None:
options["fullDocumentBeforeChange"] = self._full_document_before_change
resume_token = self.resume_token
if resume_token is not None:
if self._uses_start_after:
options["startAfter"] = resume_token
else:
options["resumeAfter"] = resume_token
elif self._start_at_operation_time is not None:
options["startAtOperationTime"] = self._start_at_operation_time
if self._show_expanded_events:
options["showExpandedEvents"] = self._show_expanded_events
return options
def _command_options(self) -> dict[str, Any]:
"""Return the options dict for the aggregation command."""
options = {}
if self._max_await_time_ms is not None:
options["maxAwaitTimeMS"] = self._max_await_time_ms
if self._batch_size is not None:
options["batchSize"] = self._batch_size
return options
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this AsyncChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None:
"""Callback that caches the postBatchResumeToken or
startAtOperationTime from a changeStream aggregate command response
containing an empty batch of change documents.
This is implemented as a callback because we need access to the wire
version in order to determine whether to cache this value.
"""
if not result["cursor"]["firstBatch"]:
if "postBatchResumeToken" in result["cursor"]:
self._resume_token = result["cursor"]["postBatchResumeToken"]
elif (
self._start_at_operation_time is None
and self._uses_resume_after is False
and self._uses_start_after is False
and conn.max_wire_version >= 7
):
self._start_at_operation_time = result.get("operationTime")
# PYTHON-2181: informative error on missing operationTime.
if self._start_at_operation_time is None:
raise OperationFailure(
"Expected field 'operationTime' missing from command "
f"response : {result!r}"
)
async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
) -> AsyncCommandCursor:
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
"""
cmd = self._aggregation_command_class(
self._target,
AsyncCommandCursor,
self._aggregation_pipeline(),
self._command_options(),
explicit_session,
result_processor=self._process_result,
comment=self._comment,
)
return await self._client._retryable_read(
cmd.get_cursor,
self._target._read_preference_for(session),
session,
operation=_Op.AGGREGATE,
)
async def _create_cursor(self) -> AsyncCommandCursor:
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None
)
async def _resume(self) -> None:
"""Reestablish this change stream after a resumable error."""
try:
await self._cursor.close()
except PyMongoError:
pass
self._cursor = await self._create_cursor()
async def close(self) -> None:
"""Close this AsyncChangeStream."""
self._closed = True
await self._cursor.close()
def __aiter__(self) -> AsyncChangeStream[_DocumentType]:
return self
@property
def resume_token(self) -> Optional[Mapping[str, Any]]:
"""The cached resume token that will be used to resume after the most
recently returned change.
.. versionadded:: 3.9
"""
return copy.deepcopy(self._resume_token)
@_csot.apply
async def next(self) -> _DocumentType:
"""Advance the cursor.
This method blocks until the next change document is returned or an
unrecoverable error is raised. This method is used when iterating over
all changes in the cursor. For example::
try:
resume_token = None
pipeline = [{'$match': {'operationType': 'insert'}}]
async with await db.collection.watch(pipeline) as stream:
async for insert_change in stream:
print(insert_change)
resume_token = stream.resume_token
except pymongo.errors.PyMongoError:
# The AsyncChangeStream encountered an unrecoverable error or the
# resume attempt failed to recreate the cursor.
if resume_token is None:
# There is no usable resume token because there was a
# failure during AsyncChangeStream initialization.
logging.error('...')
else:
# Use the interrupted AsyncChangeStream's resume token to create
# a new AsyncChangeStream. The new stream will continue from the
# last seen insert change without missing any events.
async with await db.collection.watch(
pipeline, resume_after=resume_token) as stream:
async for insert_change in stream:
print(insert_change)
Raises :exc:`StopAsyncIteration` if this AsyncChangeStream is closed.
"""
while self.alive:
doc = await self.try_next()
if doc is not None:
return doc
raise StopAsyncIteration
__anext__ = next
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
.. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopAsyncIteration` and :meth:`try_next` can return ``None``.
.. versionadded:: 3.8
"""
return not self._closed
@_csot.apply
async def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next change document without waiting
indefinitely for the next change. For example::
async with await db.collection.watch() as stream:
while stream.alive:
change = await stream.try_next()
# Note that the AsyncChangeStream's resume token may be updated
# even when no changes are returned.
print("Current resume token: %r" % (stream.resume_token,))
if change is not None:
print("Change document: %r" % (change,))
continue
# We end up here when there are no recent changes.
# Sleep for a while before trying again to avoid flooding
# the server with getMore requests when no changes are
# available.
await asyncio.sleep(10)
If no change document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there have been no changes) then ``None`` is returned.
:return: The next change document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 3.8
"""
if not self._closed and not self._cursor.alive:
await self._resume()
# Attempt to get the next change with at most one getMore and at most
# one resume attempt.
try:
try:
change = await self._cursor._try_next(True)
except PyMongoError as exc:
if not _resumable(exc):
raise
await self._resume()
change = await self._cursor._try_next(False)
except PyMongoError as exc:
# Close the stream after a fatal error.
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
await self.close()
raise
# Check if the cursor was invalidated.
if not self._cursor.alive:
self._closed = True
# If no changes are available.
if change is None:
# We have either iterated over all documents in the cursor,
# OR the most-recently returned batch is empty. In either case,
# update the cached resume token with the postBatchResumeToken if
# one was returned. We also clear the startAtOperationTime.
if self._cursor._post_batch_resume_token is not None:
self._resume_token = self._cursor._post_batch_resume_token
self._start_at_operation_time = None
return change
# Else, changes are available.
try:
resume_token = change["_id"]
except KeyError:
await self.close()
raise InvalidOperation(
"Cannot provide resume functionality when the resume token is missing."
) from None
# If this is the last change document from the current batch, cache the
# postBatchResumeToken.
if not self._cursor._has_next() and self._cursor._post_batch_resume_token:
resume_token = self._cursor._post_batch_resume_token
# Hereafter, don't use startAfter; instead use resumeAfter.
self._uses_start_after = False
self._uses_resume_after = True
# Cache the resume token and clear startAtOperationTime.
self._resume_token = resume_token
self._start_at_operation_time = None
if self._decode_custom:
return _bson_to_dict(change.raw, self._orig_codec_options)
return change
async def __aenter__(self) -> AsyncChangeStream[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
class AsyncCollectionChangeStream(AsyncChangeStream[_DocumentType]):
"""A change stream that watches changes on a single collection.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.asynchronous.collection.AsyncCollection.watch` instead.
.. versionadded:: 3.7
"""
_target: AsyncCollection[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]:
return _CollectionAggregationCommand
@property
def _client(self) -> AsyncMongoClient[_DocumentType]:
return self._target.database.client
class AsyncDatabaseChangeStream(AsyncChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in a database.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.asynchronous.database.AsyncDatabase.watch` instead.
.. versionadded:: 3.7
"""
_target: AsyncDatabase[_DocumentType]
@property
def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]:
return _DatabaseAggregationCommand
@property
def _client(self) -> AsyncMongoClient[_DocumentType]:
return self._target.client
class AsyncClusterChangeStream(AsyncDatabaseChangeStream[_DocumentType]):
"""A change stream that watches changes on all collections in the cluster.
Should not be called directly by application developers. Use
helper method :meth:`pymongo.asynchronous.mongo_client.AsyncMongoClient.watch` instead.
.. versionadded:: 3.7
"""
def _change_stream_options(self) -> dict[str, Any]:
options = super()._change_stream_options()
options["allChangesForCluster"] = True
return options
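
As a hedged illustration of how the classes above are reached (the collection name and callback are invented; ``AsyncMongoClient.watch`` is the public entry point for ``AsyncClusterChangeStream``), a cluster-wide stream can be consumed like this:

    import asyncio
    from pymongo import AsyncMongoClient

    async def tail_cluster() -> None:
        client = AsyncMongoClient()
        # client.watch() builds an AsyncClusterChangeStream, which adds
        # allChangesForCluster=True to the $changeStream stage (see above).
        async with await client.watch(full_document="updateLookup") as stream:
            async for change in stream:
                print(change["operationType"], change.get("ns"))

    asyncio.run(tail_cluster())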

View File

@@ -0,0 +1,800 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The client-level bulk write operations interface.
.. versionadded:: 4.9
"""
from __future__ import annotations
import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
TYPE_CHECKING,
Any,
Mapping,
Optional,
Type,
Union,
)
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.helpers import _handle_reauth
if TYPE_CHECKING:
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo._client_bulk_shared import (
_merge_command,
_throw_client_bulk_write_exception,
)
from pymongo.common import (
validate_is_document_type,
validate_ok_for_replace,
validate_ok_for_update,
)
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
InvalidOperation,
NotPrimaryError,
OperationFailure,
WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
_ClientBulkWriteContext,
_convert_client_bulk_exception,
_convert_exception,
_convert_write_result,
_randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.results import (
ClientBulkWriteResult,
DeleteResult,
InsertOneResult,
UpdateResult,
)
from pymongo.typings import _DocumentOut, _Pipeline
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class _AsyncClientBulk:
"""The private guts of the client-level bulk write API."""
def __init__(
self,
client: AsyncMongoClient,
write_concern: WriteConcern,
ordered: bool = True,
bypass_document_validation: Optional[bool] = None,
comment: Optional[str] = None,
let: Optional[Any] = None,
verbose_results: bool = False,
) -> None:
"""Initialize a _AsyncClientBulk instance."""
self.client = client
self.write_concern = write_concern
self.let = let
if self.let is not None:
common.validate_is_document_type("let", self.let)
self.ordered = ordered
self.bypass_doc_val = bypass_document_validation
self.comment = comment
self.verbose_results = verbose_results
self.ops: list[tuple[str, Mapping[str, Any]]] = []
self.namespaces: list[str] = []
self.idx_offset: int = 0
self.total_ops: int = 0
self.executed = False
self.uses_upsert = False
self.uses_collation = False
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.is_retryable = self.client.options.retry_writes
self.retrying = False
self.started_retryable_write = False
@property
def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
return _ClientBulkWriteContext
def add_insert(self, namespace: str, document: _DocumentOut) -> None:
"""Add an insert document to the list of ops."""
validate_is_document_type("document", document)
# Generate ObjectId client side.
if not (isinstance(document, RawBSONDocument) or "_id" in document):
document["_id"] = ObjectId()
cmd = {"insert": -1, "document": document}
self.ops.append(("insert", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_update(
self,
namespace: str,
selector: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
multi: bool = False,
upsert: Optional[bool] = None,
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd = {
"update": -1,
"filter": selector,
"updateMods": update,
"multi": multi,
}
if upsert is not None:
self.uses_upsert = True
cmd["upsert"] = upsert
if array_filters is not None:
self.uses_array_filters = True
cmd["arrayFilters"] = array_filters
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append(("update", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_replace(
self,
namespace: str,
selector: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: Optional[bool] = None,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd = {
"update": -1,
"filter": selector,
"updateMods": replacement,
"multi": False,
}
if upsert is not None:
self.uses_upsert = True
cmd["upsert"] = upsert
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
self.ops.append(("replace", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
def add_delete(
self,
namespace: str,
selector: Mapping[str, Any],
multi: bool,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
"""Create a delete document and add it to the list of ops."""
cmd = {"delete": -1, "filter": selector, "multi": multi}
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if multi:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append(("delete", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
@_handle_reauth
async def write_command(
self,
bwc: _ClientBulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: Union[bytes, dict[str, Any]],
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> dict[str, Any]:
"""A proxy for AsyncConnection.write_command that handles event publishing."""
cmd["ops"] = op_docs
cmd["nsInfo"] = ns_docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._start(cmd, request_id, op_docs, ns_docs)
try:
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
# Process the response from the server.
await self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
# Process the response from the server.
if isinstance(exc, OperationFailure):
await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
else:
await self.client._process_response({}, bwc.session) # type: ignore[arg-type]
return reply # type: ignore[return-value]
async def unack_write(
self,
bwc: _ClientBulkWriteContext,
cmd: MutableMapping[str, Any],
request_id: int,
msg: bytes,
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
cmd = bwc._start(cmd, request_id, op_docs, ns_docs)
try:
result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override]
duration = datetime.datetime.now() - bwc.start_time
if result is not None:
reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type]
else:
# Comply with APM spec.
reply = {"ok": 1}
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=reply,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
)
if bwc.publish:
bwc._succeed(request_id, reply, duration)
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, OperationFailure):
failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type]
elif isinstance(exc, NotPrimaryError):
failure = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=bwc.db_name,
requestId=request_id,
operationId=request_id,
driverConnectionId=bwc.conn.id,
serverConnectionId=bwc.conn.server_connection_id,
serverHost=bwc.conn.address[0],
serverPort=bwc.conn.address[1],
serviceId=bwc.conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if bwc.publish:
assert bwc.start_time is not None
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
return reply
async def _execute_batch_unack(
self,
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (unack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
await self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
return to_send_ops, to_send_ns
async def _execute_batch(
self,
bwc: _ClientBulkWriteContext,
cmd: dict[str, Any],
ops: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Executes a batch of bulkWrite server commands (ack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
result = await self.write_command(
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]
async def _process_results_cursor(
self,
full_result: MutableMapping[str, Any],
result: MutableMapping[str, Any],
conn: AsyncConnection,
session: Optional[AsyncClientSession],
) -> None:
"""Internal helper for processing the server reply command cursor."""
if result.get("cursor"):
coll = AsyncCollection(
database=AsyncDatabase(self.client, "admin"),
name="$cmd.bulkWrite",
)
cmd_cursor = AsyncCommandCursor(
coll,
result["cursor"],
conn.address,
session=session,
explicit_session=session is not None,
comment=self.comment,
)
await cmd_cursor._maybe_pin_connection(conn)
# Iterate the cursor to get individual write results.
try:
async for doc in cmd_cursor:
original_index = doc["idx"] + self.idx_offset
op_type, op = self.ops[original_index]
if not doc["ok"]:
result["writeErrors"].append(doc)
if self.ordered:
return
# Record individual write result.
if doc["ok"] and self.verbose_results:
if op_type == "insert":
inserted_id = op["document"]["_id"]
res = InsertOneResult(inserted_id, acknowledged=True) # type: ignore[assignment]
if op_type in ["update", "replace"]:
op_type = "update"
res = UpdateResult(doc, acknowledged=True, in_client_bulk=True) # type: ignore[assignment]
if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res
except Exception as exc:
# Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive:
await cmd_cursor.close()
result["error"] = _convert_client_bulk_exception(exc)
async def _execute_command(
self,
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
final_write_concern: Optional[WriteConcern] = None,
) -> None:
"""Internal helper for executing batches of bulkWrite commands."""
db_name = "admin"
cmd_name = "bulkWrite"
listeners = self.client._event_listeners
# AsyncConnection.command validates the session, but we use
# AsyncConnection.write_command
conn.validate_session(self.client, session)
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners, # type: ignore[arg-type]
session,
self.client.codec_options,
)
while self.idx_offset < self.total_ops:
# If this is the last possible batch, use the
# final write concern.
if self.total_ops - self.idx_offset <= bwc.max_write_batch_size:
write_concern = final_write_concern or write_concern
# Construct the server command, specifying the relevant options.
cmd = {"bulkWrite": 1}
cmd["errorsOnly"] = not self.verbose_results
cmd["ordered"] = self.ordered # type: ignore[assignment]
not_in_transaction = session and not session.in_transaction
if not_in_transaction or not session:
_csot.apply_write_concern(cmd, write_concern)
if self.bypass_doc_val is not None:
cmd["bypassDocumentValidation"] = self.bypass_doc_val
if self.comment:
cmd["comment"] = self.comment # type: ignore[assignment]
if self.let:
cmd["let"] = self.let
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, self.client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(self.client, cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
if write_concern.acknowledged:
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
result = raw_result
# Top-level server/network error.
if result.get("error"):
error = result["error"]
retryable_top_level_error = (
hasattr(error, "details")
and isinstance(error.details, dict)
and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
)
retryable_network_error = isinstance(
error, ConnectionFailure
) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)
else:
_merge_command(self.ops, self.idx_offset, full_result, result)
_throw_client_bulk_write_exception(full_result, self.verbose_results)
result["error"] = None
result["writeErrors"] = []
if result.get("nErrors", 0) < len(to_send_ops):
full_result["anySuccessful"] = True
# Top-level command error.
if not result["ok"]:
result["error"] = raw_result
_merge_command(self.ops, self.idx_offset, full_result, result)
break
if retryable:
# Retryable writeConcernErrors halt the execution of this batch.
wce = result.get("writeConcernError", {})
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)
# Process the server reply as a command cursor.
await self._process_results_cursor(full_result, result, conn, session)
# Merge this batch's results with the full results.
_merge_command(self.ops, self.idx_offset, full_result, result)
# We're no longer in a retry once a command succeeds.
self.retrying = False
self.started_retryable_write = False
else:
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)
# We halt execution if we hit a top-level error,
# or an individual error in an ordered bulk write.
if full_result["error"] or (self.ordered and full_result["writeErrors"]):
break
async def execute_command(
self,
session: Optional[AsyncClientSession],
operation: str,
) -> MutableMapping[str, Any]:
"""Execute commands with w=1 WriteConcern."""
full_result: MutableMapping[str, Any] = {
"anySuccessful": False,
"error": None,
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nDeleted": 0,
"insertResults": {},
"updateResults": {},
"deleteResults": {},
}
op_id = _randint()
async def retryable_bulk(
session: Optional[AsyncClientSession],
conn: AsyncConnection,
retryable: bool,
) -> None:
if conn.max_wire_version < 25:
raise InvalidOperation(
"MongoClient.bulk_write requires MongoDB server version 8.0+."
)
await self._execute_command(
self.write_concern,
session,
conn,
op_id,
retryable,
full_result,
)
await self.client._retryable_write(
self.is_retryable,
retryable_bulk,
session,
operation,
bulk=self,
operation_id=op_id,
)
if full_result["error"] or full_result["writeErrors"] or full_result["writeConcernErrors"]:
_throw_client_bulk_write_exception(full_result, self.verbose_results)
return full_result
async def execute_command_unack_unordered(
self,
conn: AsyncConnection,
) -> None:
"""Execute commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = "admin"
cmd_name = "bulkWrite"
listeners = self.client._event_listeners
op_id = _randint()
bwc = self.bulk_ctx_class(
db_name,
cmd_name,
conn,
op_id,
listeners, # type: ignore[arg-type]
None,
self.client.codec_options,
)
while self.idx_offset < self.total_ops:
# Construct the server command, specifying the relevant options.
cmd = {"bulkWrite": 1}
cmd["errorsOnly"] = not self.verbose_results
cmd["ordered"] = self.ordered # type: ignore[assignment]
if self.bypass_doc_val is not None:
cmd["bypassDocumentValidation"] = self.bypass_doc_val
cmd["writeConcern"] = {"w": 0} # type: ignore[assignment]
if self.comment:
cmd["comment"] = self.comment # type: ignore[assignment]
if self.let:
cmd["let"] = self.let
conn.add_server_api(cmd)
ops = islice(self.ops, self.idx_offset, None)
namespaces = islice(self.namespaces, self.idx_offset, None)
# Run as many ops as possible in one server command.
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
self.idx_offset += len(to_send_ops)
async def execute_command_unack_ordered(
self,
conn: AsyncConnection,
) -> None:
"""Execute commands with OP_MSG and w=0 WriteConcern, ordered."""
full_result: MutableMapping[str, Any] = {
"anySuccessful": False,
"error": None,
"writeErrors": [],
"writeConcernErrors": [],
"nInserted": 0,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nDeleted": 0,
"insertResults": {},
"updateResults": {},
"deleteResults": {},
}
# Ordered bulk writes have to be acknowledged so that we stop
# processing at the first error, even when the application
# specified unacknowledged writeConcern.
initial_write_concern = WriteConcern()
op_id = _randint()
try:
await self._execute_command(
initial_write_concern,
None,
conn,
op_id,
False,
full_result,
self.write_concern,
)
except OperationFailure:
pass
async def execute_no_results(
self,
conn: AsyncConnection,
) -> None:
"""Execute all operations, returning no results (w=0)."""
if self.uses_collation:
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
if self.uses_array_filters:
raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val is not None:
raise OperationFailure(
"Cannot set bypass_document_validation with unacknowledged write concern"
)
if self.ordered:
return await self.execute_command_unack_ordered(conn)
return await self.execute_command_unack_unordered(conn)
async def execute(
self,
session: Optional[AsyncClientSession],
operation: str,
) -> Any:
"""Execute operations."""
if not self.ops:
raise InvalidOperation("No operations to execute")
if self.executed:
raise InvalidOperation("Bulk operations can only be executed once.")
self.executed = True
session = _validate_session_write_concern(session, self.write_concern)
if not self.write_concern.acknowledged:
async with await self.client._conn_for_writes(session, operation) as connection:
if connection.max_wire_version < 25:
raise InvalidOperation(
"MongoClient.bulk_write requires MongoDB server version 8.0+."
)
await self.execute_no_results(connection)
return ClientBulkWriteResult(None, False, False) # type: ignore[arg-type]
result = await self.execute_command(session, operation)
return ClientBulkWriteResult(
result,
self.write_concern.acknowledged,
self.verbose_results,
)
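
For context, a hedged sketch of the public API that drives ``_AsyncClientBulk`` (the namespaces and documents are invented, and the ``namespace=`` keyword on the operation models follows the PyMongo 4.9 client bulk-write API, so treat the exact signatures as an assumption):

    from pymongo import AsyncMongoClient, InsertOne, UpdateOne

    async def client_level_bulk() -> None:
        client = AsyncMongoClient()
        # Each model carries its own target namespace ("db.collection").
        models = [
            InsertOne({"x": 1}, namespace="test.one"),
            UpdateOne({"x": 1}, {"$inc": {"x": 1}}, namespace="test.two"),
        ]
        # Requires MongoDB 8.0+ (wire version 25), as enforced in execute() above.
        result = await client.bulk_write(models, ordered=True, verbose_results=True)
        print(result.inserted_count, result.modified_count)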

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,470 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CommandCursor class to iterate over command results."""
from __future__ import annotations
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo import _csot
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import (
_CursorAddress,
_GetMore,
_OpMsg,
_OpReply,
_RawBatchGetMore,
)
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentOut, _DocumentType
if TYPE_CHECKING:
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import AsyncConnection
_IS_SYNC = False
class AsyncCommandCursor(Generic[_DocumentType]):
"""An asynchronous cursor / iterator over command cursors."""
_getmore_class = _GetMore
def __init__(
self,
collection: AsyncCollection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new command cursor."""
self._sock_mgr: Any = None
self._collection: AsyncCollection[_DocumentType] = collection
self._id = cursor_info["id"]
self._data = deque(cursor_info["firstBatch"])
self._postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get(
"postBatchResumeToken"
)
self._address = address
self._batch_size = batch_size
self._max_await_time_ms = max_await_time_ms
self._timeout = self._collection.database.client.options.timeout
self._session = session
self._explicit_session = explicit_session
self._killed = self._id == 0
self._comment = comment
if self._killed:
self._end_session()
if "ns" in cursor_info: # noqa: SIM401
self._ns = cursor_info["ns"]
else:
self._ns = collection.full_name
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self) -> None:
self._die_no_lock()
def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]:
"""Limits the number of documents returned in one batch. Each batch
requires a round trip to the server. It can be adjusted to optimize
performance and limit data transfer.
.. note:: batch_size can not override MongoDB's internal limits on the
amount of data it will return to the client in a single batch (i.e
if you set batch size to 1,000,000,000, MongoDB will currently only
return 4-16MB of results per batch).
Raises :exc:`TypeError` if `batch_size` is not an integer.
Raises :exc:`ValueError` if `batch_size` is less than ``0``.
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self._batch_size = batch_size == 1 and 2 or batch_size
return self
def _has_next(self) -> bool:
"""Returns `True` if the cursor has documents remaining from the
previous batch.
"""
return len(self._data) > 0
@property
def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]:
"""Retrieve the postBatchResumeToken from the response to a
changeStream aggregate or getMore.
"""
return self._postbatchresumetoken
async def _maybe_pin_connection(self, conn: AsyncConnection) -> None:
client = self._collection.database.client
if not client._should_pin_cursor(self._session):
return
if not self._sock_mgr:
conn.pin_cursor()
conn_mgr = _ConnectionManager(conn, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self._id == 0:
await conn_mgr.close()
else:
self._sock_mgr = conn_mgr
def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
@property
def alive(self) -> bool:
"""Does this cursor have the potential to return more data?
Even if :attr:`alive` is ``True``, :meth:`next` can raise
:exc:`StopAsyncIteration`. Best to use an async for loop::
async for doc in collection.aggregate(pipeline):
print(doc)
.. note:: :attr:`alive` can be True while iterating a cursor from
a failed server. In this case :attr:`alive` will return False after
:meth:`next` fails to retrieve the next batch of results from the
server.
"""
return bool(len(self._data) or (not self._killed))
@property
def cursor_id(self) -> int:
"""Returns the id of the cursor."""
return self._id
@property
def address(self) -> Optional[_Address]:
"""The (host, port) of the server used, or None.
.. versionadded:: 3.0
"""
return self._address
@property
def session(self) -> Optional[AsyncClientSession]:
"""The cursor's :class:`~pymongo.asynchronous.client_session.AsyncClientSession`, or None.
.. versionadded:: 3.6
"""
if self._explicit_session:
return self._session
return None
def _prepare_to_die(self) -> tuple[int, Optional[_CursorAddress]]:
already_killed = self._killed
self._killed = True
if self._id and not already_killed:
cursor_id = self._id
assert self._address is not None
address = _CursorAddress(self._address, self._ns)
else:
# Skip killCursors.
cursor_id = 0
address = None
return cursor_id, address
def _die_no_lock(self) -> None:
"""Closes this cursor without acquiring a lock."""
cursor_id, address = self._prepare_to_die()
self._collection.database.client._cleanup_cursor_no_lock(
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
async def _die_lock(self) -> None:
"""Closes this cursor."""
cursor_id, address = self._prepare_to_die()
await self._collection.database.client._cleanup_cursor_lock(
cursor_id,
address,
self._sock_mgr,
self._session,
self._explicit_session,
)
if not self._explicit_session:
self._session = None
self._sock_mgr = None
def _end_session(self) -> None:
if self._session and not self._explicit_session:
self._session._end_implicit_session()
self._session = None
async def close(self) -> None:
"""Explicitly close / kill this cursor."""
await self._die_lock()
async def _send_message(self, operation: _GetMore) -> None:
"""Send a getmore message and handle the response."""
client = self._collection.database.client
try:
response = await client._run_operation(
operation, self._unpack_response, address=self._address
)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
# Don't send killCursors because the cursor is already closed.
self._killed = True
if exc.timeout:
self._die_no_lock()
else:
# Return the session and pinned connection, if necessary.
await self.close()
raise
except ConnectionFailure:
# Don't send killCursors because the cursor is already closed.
self._killed = True
# Return the session and pinned connection, if necessary.
await self.close()
raise
except Exception:
await self.close()
raise
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
self._id = cursor["id"]
else:
documents = response.docs
assert isinstance(response.data, _OpReply)
self._id = response.data.cursor_id
if self._id == 0:
await self.close()
self._data = deque(documents)
async def _refresh(self) -> int:
"""Refreshes the cursor with more data from the server.
Returns the length of self._data after refresh. Will exit early if
self._data is already non-empty. Raises OperationFailure when the
cursor cannot be refreshed due to an error on the query.
"""
if len(self._data) or self._killed:
return len(self._data)
if self._id: # Get More
dbname, collname = self._ns.split(".", 1)
read_pref = self._collection._read_preference_for(self.session)
await self._send_message(
self._getmore_class(
dbname,
collname,
self._batch_size,
self._id,
self._collection.codec_options,
read_pref,
self._session,
self._collection.database.client,
self._max_await_time_ms,
self._sock_mgr,
False,
self._comment,
)
)
else: # Cursor id is zero nothing else to return
await self._die_lock()
return len(self._data)
def __aiter__(self) -> AsyncIterator[_DocumentType]:
return self
async def next(self) -> _DocumentType:
"""Advance the cursor."""
# Block until a document is returnable.
while self.alive:
doc = await self._try_next(True)
if doc is not None:
return doc
raise StopAsyncIteration
async def __anext__(self) -> _DocumentType:
return await self.next()
async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
"""Advance the cursor blocking for at most one getMore command."""
if not len(self._data) and not self._killed and get_more_allowed:
await self._refresh()
if len(self._data):
return self._data.popleft()
else:
return None
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
await self._refresh()
if len(self._data):
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
async def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
This method returns the next document without waiting
indefinitely for data.
If no document is cached locally then this method runs a single
getMore command. If the getMore yields any documents, the next
document is returned, otherwise, if the getMore returns no documents
(because there is no additional data) then ``None`` is returned.
:return: The next document or ``None`` when no document is available
after running a single getMore or when the cursor is closed.
.. versionadded:: 4.5
"""
return await self._try_next(get_more_allowed=True)
async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
@_csot.apply
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
_getmore_class = _RawBatchGetMore
def __init__(
self,
collection: AsyncCollection[_DocumentType],
cursor_info: Mapping[str, Any],
address: Optional[_Address],
batch_size: int = 0,
max_await_time_ms: Optional[int] = None,
session: Optional[AsyncClientSession] = None,
explicit_session: bool = False,
comment: Any = None,
) -> None:
"""Create a new cursor / iterator over raw batches of BSON data.
Should not be called directly by application developers -
see :meth:`~pymongo.asynchronous.collection.AsyncCollection.aggregate_raw_batches`
instead.
.. seealso:: The MongoDB documentation on `cursors <https://dochub.mongodb.org/core/cursors>`_.
"""
assert not cursor_info.get("firstBatch")
super().__init__(
collection,
cursor_info,
address,
batch_size,
max_await_time_ms,
session,
explicit_session,
comment,
)
def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:
raw_response = response.raw_response(cursor_id, user_fields=user_fields)
if not legacy_response:
# OP_MSG returns firstBatch/nextBatch documents as a BSON array
# Re-assemble the array of documents into a document stream
_convert_raw_document_lists_to_streams(raw_response[0])
return raw_response # type: ignore[return-value]
def __getitem__(self, index: int) -> NoReturn:
raise InvalidOperation("Cannot call __getitem__ on AsyncRawBatchCommandCursor")
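
A short, hedged usage sketch for the cursor class above (the collection and pipeline are invented; ``aggregate`` and ``to_list`` are public PyMongo APIs): ``AsyncCollection.aggregate`` returns an ``AsyncCommandCursor``, which can be drained eagerly or iterated batch by batch.

    from pymongo import AsyncMongoClient

    async def run_pipeline() -> None:
        client = AsyncMongoClient()
        orders = client.test.orders
        # to_list() pulls everything that remains, issuing getMores as needed.
        cursor = await orders.aggregate([{"$match": {"status": "A"}}, {"$limit": 50}])
        docs = await cursor.to_list()
        print(len(docs))
        # Or iterate lazily; each exhausted batch triggers one getMore.
        async for doc in await orders.aggregate([{"$sample": {"size": 5}}]):
            print(doc)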

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,82 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations
import builtins
import sys
from typing import (
Any,
Callable,
TypeVar,
cast,
)
from pymongo.errors import (
OperationFailure,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
_IS_SYNC = False
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])
def _handle_reauth(func: F) -> F:
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.message import _BulkWriteContext
try:
return await func(*args, **kwargs)
except OperationFailure as exc:
if no_reauth:
raise
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
# Look for an argument that either is a AsyncConnection
# or has a connection attribute, so we can trigger
# a reauth.
conn = None
for arg in args:
if isinstance(arg, AsyncConnection):
conn = arg
break
if isinstance(arg, _BulkWriteContext):
conn = arg.conn # type: ignore[assignment]
break
if conn:
await conn.authenticate(reauthenticate=True)
else:
raise
return await func(*args, **kwargs)
raise
return cast(F, inner)
if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
else:
async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return await cls.__anext__()
def aiter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return cls.__aiter__()
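
A small, hedged example of why the shim matters (the cursor argument is illustrative): on Python 3.9 there is no ``builtins.anext``, so driver code that needs exactly one item from an async iterator imports this helper instead.

    from pymongo.asynchronous.helpers import anext

    async def first_item(cursor):
        # Works on any async iterator, e.g. an AsyncCommandCursor or an
        # AsyncChangeStream, on Python 3.9 and newer alike.
        return await anext(cursor)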

File diff suppressed because it is too large

View File

@@ -0,0 +1,534 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Class to monitor a MongoDB server on a background thread."""
from __future__ import annotations
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo._csot import MovingMinimum
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
_IS_SYNC = False
def _sanitize(error: Exception) -> None:
"""PYTHON-2433 Clear error traceback info."""
error.__traceback__ = None
error.__context__ = None
error.__cause__ = None
def _monotonic_duration(start: float) -> float:
"""Return the duration since the given start time.
Accounts for buggy platforms where time.monotonic() is not monotonic.
See PYTHON-4600.
"""
return max(0.0, time.monotonic() - start)
class MonitorBase:
def __init__(self, topology: Topology, name: str, interval: int, min_interval: float):
"""Base class to do periodic work on a background thread.
The background thread is signaled to stop when the Topology or
this instance is freed.
"""
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
async def target() -> bool:
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
await monitor._run() # type:ignore[attr-defined]
return True
executor = periodic_executor.PeriodicExecutor(
interval=interval, min_interval=min_interval, target=target, name=name
)
self._executor = executor
def _on_topology_gc(dummy: Optional[Topology] = None) -> None:
# This prevents GC from waiting 10 seconds for hello to complete
# See test_cleanup_executors_on_client_del.
monitor = self_ref()
if monitor:
monitor.gc_safe_close()
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, _on_topology_gc)
_register(self)
def open(self) -> None:
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._executor.open()
def gc_safe_close(self) -> None:
"""GC safe close."""
self._executor.close()
async def close(self) -> None:
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
self._executor.wake()
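# Illustrative sketch (not part of the original module): the weak-reference
# pattern MonitorBase uses above, reduced to plain Python. The owner strongly
# references the worker closure, while the closure only holds a weakref back,
# so freeing the owner makes the periodic target report "stop". Assumes CPython
# reference counting, which collects the owner as soon as the last strong
# reference goes away.
def _weakref_pattern_demo() -> None:
    class _DemoOwner:
        def __init__(self) -> None:
            def target() -> bool:
                owner = self_ref()
                return owner is not None  # False once the owner is gone

            self.target = target
            self_ref = weakref.ref(self)

    owner = _DemoOwner()
    keep_target = owner.target
    assert keep_target() is True
    del owner  # last strong reference; CPython collects the owner here
    assert keep_target() is False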
class Monitor(MonitorBase):
def __init__(
self,
server_description: ServerDescription,
topology: Topology,
pool: Pool,
topology_settings: TopologySettings,
):
"""Class to monitor a MongoDB server on a background thread.
Pass an initial ServerDescription, a Topology, a Pool, and
TopologySettings.
The Topology is weakly referenced. The Pool must be exclusive to this
Monitor.
"""
super().__init__(
topology,
"pymongo_server_monitor_thread",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
self._server_description = server_description
self._pool = pool
self._settings = topology_settings
self._listeners = self._settings._pool_options._event_listeners
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
self._cancel_context: Optional[_CancellationContext] = None
self._rtt_monitor = _RttMonitor(
topology,
topology_settings,
topology._create_pool_for_monitor(server_description.address),
)
if topology_settings.server_monitoring_mode == "stream":
self._stream = True
elif topology_settings.server_monitoring_mode == "poll":
self._stream = False
else:
self._stream = not _is_faas()
def cancel_check(self) -> None:
"""Cancel any concurrent hello check.
Note: this is called from a weakref.proxy callback and MUST NOT take
any locks.
"""
context = self._cancel_context
if context:
# Note: we cannot close the socket because doing so may cause
# concurrent reads/writes to hang until a timeout occurs
# (depending on the platform).
context.cancel()
async def _start_rtt_monitor(self) -> None:
"""Start an _RttMonitor that periodically runs ping."""
# If this monitor is closed directly before (or during) this open()
# call, the _RttMonitor will not be closed. Checking if this monitor
# was closed directly after resolves the race.
self._rtt_monitor.open()
if self._executor._stopped:
await self._rtt_monitor.close()
def gc_safe_close(self) -> None:
self._executor.close()
self._rtt_monitor.gc_safe_close()
self.cancel_check()
async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._reset_connection()
async def _reset_connection(self) -> None:
# Clear our pooled connection.
await self._pool.reset()
async def _run(self) -> None:
try:
prev_sd = self._server_description
try:
self._server_description = await self._check_server()
except _OperationCancelled as exc:
_sanitize(exc)
# Already closed the connection, wait for the next check.
self._server_description = ServerDescription(
self._server_description.address, error=exc
)
if prev_sd.is_server_type_known:
# Immediately retry since we've already waited 500ms to
# discover that we've been cancelled.
self._executor.skip_sleep()
return
# Update the Topology and clear the server pool on error.
await self._topology.on_change(
self._server_description,
reset_pool=self._server_description.error,
interrupt_connections=isinstance(self._server_description.error, NetworkTimeout),
)
if self._stream and (
self._server_description.is_server_type_known
and self._server_description.topology_version
):
await self._start_rtt_monitor()
# Immediately check for the next streaming response.
self._executor.skip_sleep()
if self._server_description.error and prev_sd.is_server_type_known:
# Immediately retry on network errors.
self._executor.skip_sleep()
except ReferenceError:
# Topology was garbage-collected.
await self.close()
async def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
Returns a ServerDescription.
"""
start = time.monotonic()
try:
try:
return await self._check_once()
except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails.
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except ReferenceError:
raise
except Exception as error:
_sanitize(error)
sd = self._server_description
address = sd.address
duration = _monotonic_duration(start)
awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version)
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology._topology_id,
serverHost=address[0],
serverPort=address[1],
awaited=awaited,
durationMS=duration * 1000,
failure=error,
message=_SDAMStatusMessage.HEARTBEAT_FAIL,
)
await self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)
async def _check_once(self) -> ServerDescription:
"""A single attempt to call hello.
Returns a ServerDescription, or raises an exception.
"""
address = self._server_description.address
sd = self._server_description
# XXX: "awaited" could be incorrectly set to True in the rare case
# the pool checkout closes and recreates a connection.
awaited = bool(
self._pool.conns and self._stream and sd.is_server_type_known and sd.topology_version
)
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_started(address, awaited)
if self._cancel_context and self._cancel_context.cancelled:
await self._reset_connection()
async with self._pool.checkout() as conn:
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology._topology_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=address[0],
serverPort=address[1],
awaited=awaited,
message=_SDAMStatusMessage.HEARTBEAT_START,
)
self._cancel_context = conn.cancel_context
response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
avg_rtt, min_rtt = self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_succeeded(
address, round_trip_time, response, response.awaitable
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology._topology_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=address[0],
serverPort=address[1],
awaited=awaited,
durationMS=round_trip_time * 1000,
reply=response.document,
message=_SDAMStatusMessage.HEARTBEAT_SUCCESS,
)
return sd
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.
"""
cluster_time = self._topology.max_cluster_time()
start = time.monotonic()
if conn.more_to_come:
# Read the next streaming hello (MongoDB 4.4+).
response = Hello(await conn._next_reply(), awaitable=True)
elif (
self._stream and conn.performed_handshake and self._server_description.topology_version
):
# Initiate streaming hello (MongoDB 4.4+).
response = await conn._hello(
cluster_time,
self._server_description.topology_version,
self._settings.heartbeat_frequency,
)
else:
# New connection handshake or polling hello (MongoDB <4.4).
response = await conn._hello(cluster_time, None, None)
duration = _monotonic_duration(start)
return response, duration
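# Illustrative sketch (not part of the original module): the streaming/polling
# decision made in Monitor.__init__ above, reduced to a helper. "auto" (the
# default) streams unless running on a FaaS platform.
def _demo_resolve_stream(mode: str, is_faas: bool) -> bool:
    if mode == "stream":
        return True
    if mode == "poll":
        return False
    return not is_faas  # "auto"

assert _demo_resolve_stream("auto", is_faas=True) is False
assert _demo_resolve_stream("poll", is_faas=False) is False
assert _demo_resolve_stream("stream", is_faas=True) is True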
class SrvMonitor(MonitorBase):
def __init__(self, topology: Topology, topology_settings: TopologySettings):
"""Class to poll SRV records on a background thread.
Pass a Topology and a TopologySettings.
The Topology is weakly referenced.
"""
super().__init__(
topology,
"pymongo_srv_polling_thread",
common.MIN_SRV_RESCAN_INTERVAL,
topology_settings.heartbeat_frequency,
)
self._settings = topology_settings
self._seedlist = self._settings._seeds
assert isinstance(self._settings.fqdn, str)
self._fqdn: str = self._settings.fqdn
self._startup_time = time.monotonic()
async def _run(self) -> None:
# Don't poll right after creation; wait 60 seconds first
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
return
seedlist = self._get_seedlist()
if seedlist:
self._seedlist = seedlist
try:
await self._topology.on_srv_update(self._seedlist)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
"""Poll SRV records for a seedlist.
Returns a list of (host, port) address tuples, or None if the lookup fails.
"""
try:
resolver = _SrvResolver(
self._fqdn,
self._settings.pool_options.connect_timeout,
self._settings.srv_service_name,
)
seedlist, ttl = resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
# - SRV records must be rescanned every heartbeatFrequencyMS
# - Topology must be left unchanged
self.request_check()
return None
else:
self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
return seedlist
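# Illustrative sketch (not part of the original module): how SrvMonitor above
# adapts its rescan interval to the DNS TTL while never dropping below the
# spec minimum (common.MIN_SRV_RESCAN_INTERVAL, assumed to be 60 seconds here).
def _demo_next_rescan_interval(ttl: float, min_rescan: float = 60.0) -> float:
    return max(ttl, min_rescan)

assert _demo_next_rescan_interval(30.0) == 60.0    # short TTLs are clamped up
assert _demo_next_rescan_interval(300.0) == 300.0  # long TTLs are honored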
class _RttMonitor(MonitorBase):
def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
"""Maintain round trip times for a server.
The Topology is weakly referenced.
"""
super().__init__(
topology,
"pymongo_server_rtt_thread",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
self._pool = pool
self._moving_average = MovingAverage()
self._moving_min = MovingMinimum()
self._lock = _create_lock()
async def close(self) -> None:
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()
def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)
def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
return self._moving_average.get(), self._moving_min.get()
def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
self._moving_average.reset()
self._moving_min.reset()
async def _run(self) -> None:
try:
# NOTE: This thread is only run when using the streaming
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = await self._ping()
self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except Exception:
await self._pool.reset()
async def _ping(self) -> float:
"""Run a "hello" command and return the RTT."""
async with self._pool.checkout() as conn:
if self._executor._stopped:
raise Exception("_RttMonitor closed")
start = time.monotonic()
await conn.hello()
return _monotonic_duration(start)
# Close monitors to cancel any in progress streaming checks before joining
# executor threads. For an explanation of how this works see the comment
# about _EXECUTORS in periodic_executor.py.
_MONITORS = set()
def _register(monitor: MonitorBase) -> None:
ref = weakref.ref(monitor, _unregister)
_MONITORS.add(ref)
def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None:
_MONITORS.remove(monitor_ref)
def _shutdown_monitors() -> None:
if _MONITORS is None:
return
# Copy the set. Closing monitors removes them.
monitors = list(_MONITORS)
# Close all monitors.
for ref in monitors:
monitor = ref()
if monitor:
monitor.gc_safe_close()
monitor = None
def _shutdown_resources() -> None:
# _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
shutdown = _shutdown_monitors
if shutdown: # type:ignore[truthy-function]
shutdown()
shutdown = _shutdown_executors
if shutdown: # type:ignore[truthy-function]
shutdown()
atexit.register(_shutdown_resources)


@@ -0,0 +1,414 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal network layer helper methods."""
from __future__ import annotations
import asyncio
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sequence,
Union,
cast,
)
from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
async_sendall,
)
from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.monitoring import _EventListeners
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
async def command(
conn: AsyncConnection,
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[AsyncClientSession],
client: Optional[AsyncMongoClient],
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
address: Optional[_Address] = None,
listeners: Optional[_EventListeners] = None,
max_bson_size: Optional[int] = None,
read_concern: Optional[ReadConcern] = None,
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
use_op_msg: bool = False,
unacknowledged: bool = False,
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
"""Execute a command over the socket, or raise socket.error.
:param conn: an AsyncConnection instance
:param dbname: name of the database on which to run the command
:param spec: a command document as an ordered dict type, e.g. SON.
:param is_mongos: are we connected to a mongos?
:param read_preference: a read preference
:param codec_options: a CodecOptions instance
:param session: optional AsyncClientSession instance.
:param client: optional AsyncMongoClient instance for updating $clusterTime.
:param check: raise OperationFailure if there are errors
:param allowable_errors: errors to ignore if `check` is True
:param address: the (host, port) of `conn`
:param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
:param max_bson_size: The maximum encoded bson size for this server
:param read_concern: The read concern for this command.
:param parse_write_concern_error: Whether to parse the ``writeConcernError``
field in the command response.
:param collation: The collation for this command.
:param compression_ctx: optional compression Context.
:param use_op_msg: True if we should use OP_MSG.
:param unacknowledged: True if this is an unacknowledged command.
:param user_fields: Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
"""
name = next(iter(spec))
ns = dbname + ".$cmd"
speculative_hello = False
# Publish the original command document, perhaps with lsid and $clusterTime.
orig = spec
if is_mongos and not use_op_msg:
assert read_preference is not None
spec = message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session.in_transaction):
if read_concern.level:
spec["readConcern"] = read_concern.document
if session:
session._update_read_concern(spec, conn)
if collation is not None:
spec["collation"] = collation
publish = listeners is not None and listeners.enabled_for_commands
start = datetime.datetime.now()
if publish:
speculative_hello = _is_speculative_authenticate(name, spec)
if compression_ctx and name.lower() in _NO_COMPRESSION:
compression_ctx = None
if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)
# Support CSOT
if client:
conn.apply_timeout(client, spec)
_csot.apply_write_concern(spec, write_concern)
if use_op_msg:
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
message._raise_document_too_large(name, size, max_bson_size)
else:
request_id, msg, size = message._query(
0, ns, 0, -1, spec, None, codec_options, compression_ctx
)
if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=spec,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_start(
orig,
dbname,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
)
try:
await async_sendall(conn.conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
)
response_doc = unpacked_docs[0]
if client:
await client._process_response(response_doc, session)
if check:
helpers_shared._check_command_response(
response_doc,
conn.max_wire_version,
allowable_errors,
parse_write_concern_error=parse_write_concern_error,
)
except Exception as exc:
duration = datetime.datetime.now() - start
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = message._convert_exception(exc)
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_failure(
duration,
failure,
name,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbname,
)
raise
duration = datetime.datetime.now() - start
if client is not None:
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=response_doc,
commandName=next(iter(spec)),
databaseName=dbname,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
speculative_authenticate="speculativeAuthenticate" in orig,
)
if publish:
assert listeners is not None
assert address is not None
listeners.publish_command_success(
duration,
response_doc,
name,
request_id,
address,
conn.server_connection_id,
service_id=conn.service_id,
speculative_hello=speculative_hello,
database_name=dbname,
)
if client and client._encrypter and reply:
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
response_doc = cast(
"_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
)
return response_doc # type: ignore[return-value]
async def receive_message(
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
await _receive_data_on_socket(conn, 16, deadline)
)
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await _receive_data_on_socket(conn, 9, deadline)
)
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
else:
data = await _receive_data_on_socket(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
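# Illustrative sketch (not part of the original module): the fixed wire framing
# that receive_message() above relies on. Every MongoDB message starts with a
# 16-byte header of four little-endian int32s; OP_COMPRESSED (opcode 2012) adds
# a 9-byte compression header (originalOpcode int32, uncompressedSize int32,
# compressorId uint8), which is why the compressed branch reads length - 25
# further bytes.
import struct

_demo_header = struct.pack("<iiii", 16, 42, 7, 2013)  # length, requestID, responseTo, opCode
_length, _request_id, _response_to, _op_code = struct.unpack("<iiii", _demo_header)
assert _length == 16 and _op_code == 2013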
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
return
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
readable = True
else:
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled")
if readable:
return
if timed_out:
raise socket.timeout("timed out")
await asyncio.sleep(0)
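# Illustrative sketch (not part of the original module): the timeout clamping
# used by wait_for_read() above. Each select() call waits at most _POLL_TIMEOUT
# (assumed to be 0.5 seconds here) so cancellation is noticed promptly, but it
# never waits past the caller's deadline.
def _demo_clamp_poll_timeout(remaining: float, poll_timeout: float = 0.5) -> float:
    return max(min(remaining, poll_timeout), 0)

assert _demo_clamp_poll_timeout(2.0) == 0.5   # plenty of time left: poll in 500ms slices
assert _demo_clamp_poll_timeout(0.1) == 0.1   # near the deadline: shorten the wait
assert _demo_clamp_poll_timeout(-1.0) == 0    # past the deadline: one final non-blocking check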
async def _receive_data_on_socket(
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
await wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
return mv


@@ -0,0 +1,219 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
from __future__ import annotations
import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional
from pymongo.lock import _ALock, _create_lock
_IS_SYNC = False
class PeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = _ALock(_create_lock())
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
# Mitigation to RuntimeError firing when thread starts on shutdown
# https://github.com/python/cpython/issues/114570
try:
thread.start()
except RuntimeError as e:
if "interpreter shutdown" in str(e) or sys.is_finalizing():
self._thread = None
return
raise
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _should_stop(self) -> bool:
async with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
async def _run(self) -> None:
while not await self._should_stop():
try:
if not await self._target():
self._stopped = True
break
except BaseException:
async with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor: PeriodicExecutor) -> None:
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
_EXECUTORS.remove(ref)
def _shutdown_executors() -> None:
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None
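# Illustrative sketch (not part of the original module): the expected usage
# pattern for PeriodicExecutor, mirroring MonitorBase earlier in this commit.
# The function is never called; it exists only to show the API surface.
def _example_usage() -> None:
    async def target() -> bool:
        # Do one unit of periodic work; return False to stop the executor.
        return True

    executor = PeriodicExecutor(interval=10, min_interval=0.5, target=target, name="demo")
    executor.open()   # start the background thread
    executor.wake()   # run the target again soon
    executor.close()  # signal it to stop
    executor.join(1)  # wait (up to 1s) for the thread to exit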

File diff suppressed because it is too large


@@ -0,0 +1,383 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Communicate with one MongoDB server in a topology."""
from __future__ import annotations
import logging
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Callable,
Optional,
Union,
)
from bson import _decode_all_selective
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.errors import NotPrimaryError, OperationFailure
from pymongo.helpers_shared import _check_command_response
from pymongo.logger import (
_COMMAND_LOGGER,
_SDAM_LOGGER,
_CommandStatusMessage,
_debug_log,
_SDAMStatusMessage,
)
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
from pymongo.response import PinnedResponse, Response
if TYPE_CHECKING:
from queue import Queue
from weakref import ReferenceType
from bson.objectid import ObjectId
from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.pool import AsyncConnection, Pool
from pymongo.monitoring import _EventListeners
from pymongo.read_preferences import _ServerMode
from pymongo.server_description import ServerDescription
from pymongo.typings import _DocumentOut
_IS_SYNC = False
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
class Server:
def __init__(
self,
server_description: ServerDescription,
pool: Pool,
monitor: Monitor,
topology_id: Optional[ObjectId] = None,
listeners: Optional[_EventListeners] = None,
events: Optional[ReferenceType[Queue]] = None,
) -> None:
"""Represent one MongoDB server."""
self._description = server_description
self._pool = pool
self._monitor = monitor
self._topology_id = topology_id
self._publish = listeners is not None and listeners.enabled_for_server
self._listener = listeners
self._events = None
if self._publish:
self._events = events() # type: ignore[misc]
async def open(self) -> None:
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
if not self._pool.opts.load_balanced:
self._monitor.open()
async def reset(self, service_id: Optional[ObjectId] = None) -> None:
"""Clear the connection pool."""
await self.pool.reset(service_id)
async def close(self) -> None:
"""Clear the connection pool and stop the monitor.
Reconnect with open().
"""
if self._publish:
assert self._listener is not None
assert self._events is not None
self._events.put(
(
self._listener.publish_server_closed,
(self._description.address, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
serverHost=self._description.address[0],
serverPort=self._description.address[1],
message=_SDAMStatusMessage.STOP_SERVER,
)
await self._monitor.close()
await self._pool.close()
def request_check(self) -> None:
"""Check the server's state soon."""
self._monitor.request_check()
async def operation_to_command(
self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False
) -> tuple[dict[str, Any], str]:
cmd, db = operation.as_command(conn, apply_timeout)
# Support auto encryption
if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption:
cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment]
operation.db, cmd, operation.codec_options
)
operation.update_command(cmd)
return cmd, db
@_handle_reauth
async def run_operation(
self,
conn: AsyncConnection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., list[_DocumentOut]],
client: AsyncMongoClient,
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.
This method is used only to run _Query/_GetMore operations from
cursors.
Can raise ConnectionFailure, OperationFailure, etc.
:param conn: An AsyncConnection instance.
:param operation: A _Query or _GetMore object.
:param read_preference: The read preference to use.
:param listeners: Instance of _EventListeners or None.
:param unpack_res: A callable that decodes the wire protocol response.
:param client: An AsyncMongoClient instance.
"""
assert listeners is not None
publish = listeners.enabled_for_commands
start = datetime.now()
use_cmd = operation.use_command(conn)
more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come
cmd, dbn = await self.operation_to_command(operation, conn, use_cmd)
if more_to_come:
request_id = 0
else:
message = operation.get_message(read_preference, conn, use_cmd)
request_id, data, max_doc_size = self._split_message(message)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.STARTED,
command=cmd,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
if "$db" not in cmd:
cmd["$db"] = dbn
assert listeners is not None
listeners.publish_command_start(
cmd,
dbn,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
)
try:
if more_to_come:
reply = await conn.receive_message(None)
else:
await conn.send_message(data, max_doc_size)
reply = await conn.receive_message(request_id)
# Unpack and check for command errors.
if use_cmd:
user_fields = _CURSOR_DOC_FIELDS
legacy_response = False
else:
user_fields = None
legacy_response = True
docs = unpack_res(
reply,
operation.cursor_id,
operation.codec_options,
legacy_response=legacy_response,
user_fields=user_fields,
)
if use_cmd:
first = docs[0]
await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type]
_check_command_response(first, conn.max_wire_version)
except Exception as exc:
duration = datetime.now() - start
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.FAILED,
durationMS=duration,
failure=failure,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
isServerSideError=isinstance(exc, OperationFailure),
)
if publish:
assert listeners is not None
listeners.publish_command_failure(
duration,
failure,
operation.name,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbn,
)
raise
duration = datetime.now() - start
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res = docs[0]
elif operation.name == "explain":
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
if operation.name == "find":
res["cursor"]["firstBatch"] = docs
else:
res["cursor"]["nextBatch"] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
clientId=client._topology_settings._topology_id,
message=_CommandStatusMessage.SUCCEEDED,
durationMS=duration,
reply=res,
commandName=next(iter(cmd)),
databaseName=dbn,
requestId=request_id,
operationId=request_id,
driverConnectionId=conn.id,
serverConnectionId=conn.server_connection_id,
serverHost=conn.address[0],
serverPort=conn.address[1],
serviceId=conn.service_id,
)
if publish:
assert listeners is not None
listeners.publish_command_success(
duration,
res,
operation.name,
request_id,
conn.address,
conn.server_connection_id,
service_id=conn.service_id,
database_name=dbn,
)
# Decrypt response.
client = operation.client # type: ignore[assignment]
if client and client._encrypter:
if use_cmd:
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
docs = _decode_all_selective(decrypted, operation.codec_options, user_fields)
response: Response
if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type]
conn.pin_cursor()
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
# more_to_come flag is set.
more_to_come = reply.more_to_come
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.conn_mgr:
operation.conn_mgr.update_exhaust(more_to_come)
response = PinnedResponse(
data=reply,
address=self._description.address,
conn=conn,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs,
more_to_come=more_to_come,
)
else:
response = Response(
data=reply,
address=self._description.address,
duration=duration,
request_id=request_id,
from_command=use_cmd,
docs=docs,
)
return response
async def checkout(
self, handler: Optional[_MongoClientErrorHandler] = None
) -> AsyncContextManager[AsyncConnection]:
return self.pool.checkout(handler)
@property
def description(self) -> ServerDescription:
return self._description
@description.setter
def description(self, server_description: ServerDescription) -> None:
assert server_description.address == self._description.address
self._description = server_description
@property
def pool(self) -> Pool:
return self._pool
def _split_message(
self, message: Union[tuple[int, Any], tuple[int, Any, int]]
) -> tuple[int, Any, int]:
"""Return request_id, data, max_doc_size.
:param message: (request_id, data, max_doc_size) or (request_id, data)
"""
if len(message) == 3:
return message # type: ignore[return-value]
else:
# get_more and kill_cursors messages don't include BSON documents.
request_id, data = message # type: ignore[misc]
return request_id, data, 0
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._description!r}>"


@@ -0,0 +1,172 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Represent MongoClient's configuration."""
from __future__ import annotations
import threading
import traceback
from typing import Any, Collection, Optional, Type, Union
from bson.objectid import ObjectId
from pymongo import common
from pymongo.asynchronous import monitor, pool
from pymongo.asynchronous.pool import Pool
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.pool_options import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector
_IS_SYNC = False
class TopologySettings:
def __init__(
self,
seeds: Optional[Collection[tuple[str, int]]] = None,
replica_set_name: Optional[str] = None,
pool_class: Optional[Type[Pool]] = None,
pool_options: Optional[PoolOptions] = None,
monitor_class: Optional[Type[monitor.Monitor]] = None,
condition_class: Optional[Type[threading.Condition]] = None,
local_threshold_ms: int = LOCAL_THRESHOLD_MS,
server_selection_timeout: int = SERVER_SELECTION_TIMEOUT,
heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY,
server_selector: Optional[_ServerSelector] = None,
fqdn: Optional[str] = None,
direct_connection: Optional[bool] = False,
load_balanced: Optional[bool] = None,
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
):
"""Represent MongoClient's configuration.
Take a list of (host, port) pairs and optional replica set name.
"""
if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL:
raise ConfigurationError(
"heartbeatFrequencyMS cannot be less than %d"
% (common.MIN_HEARTBEAT_INTERVAL * 1000,)
)
self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)]
self._replica_set_name = replica_set_name
self._pool_class: Type[Pool] = pool_class or pool.Pool
self._pool_options: PoolOptions = pool_options or PoolOptions()
self._monitor_class: Type[monitor.Monitor] = monitor_class or monitor.Monitor
self._condition_class: Type[threading.Condition] = condition_class or threading.Condition
self._local_threshold_ms = local_threshold_ms
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._fqdn = fqdn
self._heartbeat_frequency = heartbeat_frequency
self._direct = direct_connection
self._load_balanced = load_balanced
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])
@property
def seeds(self) -> Collection[tuple[str, int]]:
"""List of server addresses."""
return self._seeds
@property
def replica_set_name(self) -> Optional[str]:
return self._replica_set_name
@property
def pool_class(self) -> Type[Pool]:
return self._pool_class
@property
def pool_options(self) -> PoolOptions:
return self._pool_options
@property
def monitor_class(self) -> Type[monitor.Monitor]:
return self._monitor_class
@property
def condition_class(self) -> Type[threading.Condition]:
return self._condition_class
@property
def local_threshold_ms(self) -> int:
return self._local_threshold_ms
@property
def server_selection_timeout(self) -> int:
return self._server_selection_timeout
@property
def server_selector(self) -> Optional[_ServerSelector]:
return self._server_selector
@property
def heartbeat_frequency(self) -> int:
return self._heartbeat_frequency
@property
def fqdn(self) -> Optional[str]:
return self._fqdn
@property
def direct(self) -> Optional[bool]:
"""Connect directly to a single server, or use a set of servers?
True if there is one seed and no replica_set_name.
"""
return self._direct
@property
def load_balanced(self) -> Optional[bool]:
"""True if the client was configured to connect to a load balancer."""
return self._load_balanced
@property
def srv_service_name(self) -> str:
"""The srvServiceName."""
return self._srv_service_name
@property
def srv_max_hosts(self) -> int:
"""The srvMaxHosts."""
return self._srv_max_hosts
@property
def server_monitoring_mode(self) -> str:
"""The serverMonitoringMode."""
return self._server_monitoring_mode
def get_topology_type(self) -> int:
if self.load_balanced:
return TOPOLOGY_TYPE.LoadBalanced
elif self.direct:
return TOPOLOGY_TYPE.Single
elif self.replica_set_name is not None:
return TOPOLOGY_TYPE.ReplicaSetNoPrimary
else:
return TOPOLOGY_TYPE.Unknown
def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]:
"""Initial dict of (address, ServerDescription) for all seeds."""
return {address: ServerDescription(address) for address in self.seeds}
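# Illustrative sketch (not part of the original module): how constructor options
# map to an initial topology type via get_topology_type() above. The asserts
# mirror the branches in that method; the function is shown for clarity only.
def _demo_topology_types() -> None:
    assert TopologySettings(load_balanced=True).get_topology_type() == TOPOLOGY_TYPE.LoadBalanced
    assert TopologySettings(direct_connection=True).get_topology_type() == TOPOLOGY_TYPE.Single
    assert (
        TopologySettings(replica_set_name="rs0").get_topology_type()
        == TOPOLOGY_TYPE.ReplicaSetNoPrimary
    )
    assert TopologySettings().get_topology_type() == TOPOLOGY_TYPE.Unknown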

File diff suppressed because it is too large