# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The client-level bulk write operations interface.

.. versionadded:: 4.9
"""
from __future__ import annotations

import copy
import datetime
import logging
from collections.abc import MutableMapping
from itertools import islice
from typing import (
    TYPE_CHECKING,
    Any,
    Mapping,
    Optional,
    Type,
    Union,
)

from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.helpers import _handle_reauth

if TYPE_CHECKING:
    from pymongo.asynchronous.mongo_client import AsyncMongoClient
    from pymongo.asynchronous.pool import AsyncConnection

from pymongo._client_bulk_shared import (
    _merge_command,
    _throw_client_bulk_write_exception,
)
from pymongo.common import (
    validate_is_document_type,
    validate_ok_for_replace,
    validate_ok_for_update,
)
from pymongo.errors import (
    ConfigurationError,
    ConnectionFailure,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
    WaitQueueTimeoutError,
)
from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import (
    _ClientBulkWriteContext,
    _convert_client_bulk_exception,
    _convert_exception,
    _convert_write_result,
    _randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.results import (
    ClientBulkWriteResult,
    DeleteResult,
    InsertOneResult,
    UpdateResult,
)
from pymongo.typings import _DocumentOut, _Pipeline
from pymongo.write_concern import WriteConcern

_IS_SYNC = False


class _AsyncClientBulk:
    """The private guts of the client-level bulk write API.

    Operations are queued with the ``add_*`` methods (each appends one entry
    to ``self.ops`` and the matching namespace to ``self.namespaces``) and are
    then sent with :meth:`execute`, which may be called only once per
    instance.
    """

    def __init__(
        self,
        client: AsyncMongoClient,
        write_concern: WriteConcern,
        ordered: bool = True,
        bypass_document_validation: Optional[bool] = None,
        comment: Optional[str] = None,
        let: Optional[Any] = None,
        verbose_results: bool = False,
    ) -> None:
        """Initialize a _AsyncClientBulk instance.

        :param client: The client the bulk write is executed against.
        :param write_concern: The write concern used for the operation.
        :param ordered: Sent as the ``ordered`` field of every command; when
            true, execution stops after the first individual write error.
        :param bypass_document_validation: When not ``None``, sent as
            ``bypassDocumentValidation``.
        :param comment: When truthy, sent as the command's ``comment``.
        :param let: When truthy, sent as the command's ``let``. Validated as
            a document type here when it is not ``None``.
        :param verbose_results: When true, per-operation results are
            requested (``errorsOnly`` is sent as ``False``) and recorded.
        """
        self.client = client
        self.write_concern = write_concern
        self.let = let
        if self.let is not None:
            common.validate_is_document_type("let", self.let)
        self.ordered = ordered
        self.bypass_doc_val = bypass_document_validation
        self.comment = comment
        self.verbose_results = verbose_results
        # Queued operations as (op_type, command document) pairs, and the
        # namespace of each one at the same index.
        self.ops: list[tuple[str, Mapping[str, Any]]] = []
        self.namespaces: list[str] = []
        # Index into self.ops of the first operation that has not been sent.
        self.idx_offset: int = 0
        self.total_ops: int = 0
        self.executed = False
        # Feature flags recorded while operations are queued.
        self.uses_upsert = False
        self.uses_collation = False
        self.uses_array_filters = False
        self.uses_hint_update = False
        self.uses_hint_delete = False
        # Starts from the client's retry_writes option; cleared as soon as a
        # multi-document update or delete is queued.
        self.is_retryable = self.client.options.retry_writes
        self.retrying = False
        self.started_retryable_write = False

    @property
    def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
        """The context class instantiated for each execution."""
        return _ClientBulkWriteContext

    def add_insert(self, namespace: str, document: _DocumentOut) -> None:
        """Add an insert document to the list of ops.

        The caller's ``document`` is mutated in place: an ``_id`` is added
        when it has none and it is not a ``RawBSONDocument``.
        """
        validate_is_document_type("document", document)
        # Generate ObjectId client side.
        if not (isinstance(document, RawBSONDocument) or "_id" in document):
            document["_id"] = ObjectId()
        # NOTE(review): -1 is presumably a placeholder for the namespace index
        # that is filled in when the batch is encoded — confirm in
        # _ClientBulkWriteContext.batch_command.
        cmd = {"insert": -1, "document": document}
        self.ops.append(("insert", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    def add_update(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        update: Union[Mapping[str, Any], _Pipeline],
        multi: bool = False,
        upsert: Optional[bool] = None,
        collation: Optional[Mapping[str, Any]] = None,
        array_filters: Optional[list[Mapping[str, Any]]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Create an update document and add it to the list of ops.

        Optional fields (``upsert``, ``arrayFilters``, ``hint``,
        ``collation``) are only included when the corresponding argument is
        not ``None``; each one also sets the matching ``uses_*`` flag.
        A ``multi`` update marks the whole bulk write as not retryable.
        """
        validate_ok_for_update(update)
        cmd = {
            "update": -1,
            "filter": selector,
            "updateMods": update,
            "multi": multi,
        }
        if upsert is not None:
            self.uses_upsert = True
            cmd["upsert"] = upsert
        if array_filters is not None:
            self.uses_array_filters = True
            cmd["arrayFilters"] = array_filters
        if hint is not None:
            self.uses_hint_update = True
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if multi:
            # A bulk_write containing an update_many is not retryable.
            self.is_retryable = False
        self.ops.append(("update", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    def add_replace(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        replacement: Mapping[str, Any],
        upsert: Optional[bool] = None,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Create a replace document and add it to the list of ops.

        The command document has the same shape as an update (``update`` key,
        ``multi`` always ``False``) but is queued with the op type
        ``"replace"``.
        """
        validate_ok_for_replace(replacement)
        cmd = {
            "update": -1,
            "filter": selector,
            "updateMods": replacement,
            "multi": False,
        }
        if upsert is not None:
            self.uses_upsert = True
            cmd["upsert"] = upsert
        if hint is not None:
            self.uses_hint_update = True
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        self.ops.append(("replace", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    def add_delete(
        self,
        namespace: str,
        selector: Mapping[str, Any],
        multi: bool,
        collation: Optional[Mapping[str, Any]] = None,
        hint: Union[str, dict[str, Any], None] = None,
    ) -> None:
        """Create a delete document and add it to the list of ops.

        A ``multi`` delete marks the whole bulk write as not retryable.
        """
        cmd = {"delete": -1, "filter": selector, "multi": multi}
        if hint is not None:
            self.uses_hint_delete = True
            cmd["hint"] = hint
        if collation is not None:
            self.uses_collation = True
            cmd["collation"] = collation
        if multi:
            # A bulk_write containing a delete_many is not retryable.
            self.is_retryable = False
        self.ops.append(("delete", cmd))
        self.namespaces.append(namespace)
        self.total_ops += 1

    @_handle_reauth
    async def write_command(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: Union[bytes, dict[str, Any]],
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
        client: AsyncMongoClient,
    ) -> dict[str, Any]:
        """A proxy for AsyncConnection.write_command that handles event publishing.

        ``cmd`` is mutated: ``op_docs`` and ``ns_docs`` are stored under its
        ``"ops"`` and ``"nsInfo"`` keys. Started / succeeded / failed events
        are written to the command logger (at DEBUG) and published through
        ``bwc`` when ``bwc.publish`` is set.

        :return: The server reply, or ``{"error": exc}`` when anything in the
            send / log / publish / process-response sequence raised. The
            exception is returned to the caller rather than re-raised here.
        """
        cmd["ops"] = op_docs
        cmd["nsInfo"] = ns_docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=bwc.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=bwc.conn.id,
                serverConnectionId=bwc.conn.server_connection_id,
                serverHost=bwc.conn.address[0],
                serverPort=bwc.conn.address[1],
                serviceId=bwc.conn.service_id,
            )
        if bwc.publish:
            bwc._start(cmd, request_id, op_docs, ns_docs)
        try:
            reply = await bwc.conn.write_command(request_id, msg, bwc.codec)  # type: ignore[misc, arg-type]
            duration = datetime.datetime.now() - bwc.start_time
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                )
            if bwc.publish:
                bwc._succeed(request_id, reply, duration)  # type: ignore[arg-type]
            # Process the response from the server.
            await self.client._process_response(reply, bwc.session)  # type: ignore[arg-type]
        except Exception as exc:
            duration = datetime.datetime.now() - bwc.start_time
            # Server-reported errors carry their own details document; any
            # other exception is converted into one for logging/publishing.
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )

            if bwc.publish:
                bwc._fail(request_id, failure, duration)
            # Top-level error will be embedded in ClientBulkWriteException.
            reply = {"error": exc}
            # Process the response from the server.
            if isinstance(exc, OperationFailure):
                await self.client._process_response(exc.details, bwc.session)  # type: ignore[arg-type]
            else:
                await self.client._process_response({}, bwc.session)  # type: ignore[arg-type]
        return reply  # type: ignore[return-value]

    async def unack_write(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
        client: AsyncMongoClient,
    ) -> Optional[Mapping[str, Any]]:
        """A proxy for AsyncConnection.unack_write that handles event publishing.

        Started / succeeded / failed events are written to the command logger
        (at DEBUG) and published through ``bwc`` when ``bwc.publish`` is set.

        :return: The reply used for the succeeded event (``{"ok": 1}`` when
            the connection returned no result), or ``{"error": exc}`` when
            anything in the try body raised. The exception is returned to the
            caller rather than re-raised here.
        """
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=bwc.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=bwc.conn.id,
                serverConnectionId=bwc.conn.server_connection_id,
                serverHost=bwc.conn.address[0],
                serverPort=bwc.conn.address[1],
                serviceId=bwc.conn.service_id,
            )
        if bwc.publish:
            # From here on, cmd is whatever _start returned.
            cmd = bwc._start(cmd, request_id, op_docs, ns_docs)
        try:
            result = await bwc.conn.unack_write(msg, bwc.max_bson_size)  # type: ignore[func-returns-value, misc, override]
            duration = datetime.datetime.now() - bwc.start_time
            if result is not None:
                reply = _convert_write_result(bwc.name, cmd, result)  # type: ignore[arg-type]
            else:
                # Comply with APM spec.
                reply = {"ok": 1}
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                )
            if bwc.publish:
                bwc._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - bwc.start_time
            if isinstance(exc, OperationFailure):
                failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details)  # type: ignore[arg-type]
            elif isinstance(exc, NotPrimaryError):
                failure = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=bwc.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=bwc.conn.id,
                    serverConnectionId=bwc.conn.server_connection_id,
                    serverHost=bwc.conn.address[0],
                    serverPort=bwc.conn.address[1],
                    serviceId=bwc.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if bwc.publish:
                assert bwc.start_time is not None
                bwc._fail(request_id, failure, duration)
            # Top-level error will be embedded in ClientBulkWriteException.
            reply = {"error": exc}
        return reply

    async def _execute_batch_unack(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: dict[str, Any],
        ops: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        """Executes a batch of bulkWrite server commands (unack).

        The reply from :meth:`unack_write` is discarded.

        :return: The op documents and namespace documents that
            ``bwc.batch_command`` selected for this batch.
        """
        request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
        await self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client)  # type: ignore[arg-type]
        return to_send_ops, to_send_ns

    async def _execute_batch(
        self,
        bwc: _ClientBulkWriteContext,
        cmd: dict[str, Any],
        ops: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        """Executes a batch of bulkWrite server commands (ack).

        :return: A ``(result, to_send_ops, to_send_ns)`` tuple: the value
            returned by :meth:`write_command`, followed by the op documents
            and namespace documents that ``bwc.batch_command`` selected for
            this batch.
        """
        request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
        result = await self.write_command(
            bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
        )  # type: ignore[arg-type]
        return result, to_send_ops, to_send_ns  # type: ignore[return-value]

    async def _process_results_cursor(
        self,
        full_result: MutableMapping[str, Any],
        result: MutableMapping[str, Any],
        conn: AsyncConnection,
        session: Optional[AsyncClientSession],
    ) -> None:
        """Internal helper for processing the server reply command cursor.

        Does nothing when ``result`` has no truthy ``"cursor"`` entry.
        Otherwise each cursor document is inspected: failed writes are
        appended to ``result["writeErrors"]`` (which the caller must have
        initialised), and, when ``verbose_results`` is set, successful writes
        are recorded in ``full_result`` under ``insertResults``,
        ``updateResults`` or ``deleteResults``, keyed by the operation's
        index in ``self.ops``.

        Exceptions raised while iterating are not propagated; they are
        converted and stored in ``result["error"]``.
        """
        if result.get("cursor"):
            coll = AsyncCollection(
                database=AsyncDatabase(self.client, "admin"),
                name="$cmd.bulkWrite",
            )
            cmd_cursor = AsyncCommandCursor(
                coll,
                result["cursor"],
                conn.address,
                session=session,
                explicit_session=session is not None,
                comment=self.comment,
            )
            await cmd_cursor._maybe_pin_connection(conn)

            # Iterate the cursor to get individual write results.
            try:
                async for doc in cmd_cursor:
                    # doc["idx"] is relative to the batch; shift it to an
                    # index into self.ops.
                    original_index = doc["idx"] + self.idx_offset
                    op_type, op = self.ops[original_index]

                    if not doc["ok"]:
                        result["writeErrors"].append(doc)
                        # Stop reading at the first error of an ordered bulk.
                        if self.ordered:
                            return

                    # Record individual write result.
                    if doc["ok"] and self.verbose_results:
                        if op_type == "insert":
                            inserted_id = op["document"]["_id"]
                            res = InsertOneResult(inserted_id, acknowledged=True)  # type: ignore[assignment]
                        if op_type in ["update", "replace"]:
                            # Replaces are stored alongside updates, under
                            # the "updateResults" key.
                            op_type = "update"
                            res = UpdateResult(doc, acknowledged=True, in_client_bulk=True)  # type: ignore[assignment]
                        if op_type == "delete":
                            res = DeleteResult(doc, acknowledged=True)  # type: ignore[assignment]
                        full_result[f"{op_type}Results"][original_index] = res

            except Exception as exc:
                # Attempt to close the cursor, then raise top-level error.
                if cmd_cursor.alive:
                    await cmd_cursor.close()
                result["error"] = _convert_client_bulk_exception(exc)

    async def _execute_command(
        self,
        write_concern: WriteConcern,
        session: Optional[AsyncClientSession],
        conn: AsyncConnection,
        op_id: int,
        retryable: bool,
        full_result: MutableMapping[str, Any],
        final_write_concern: Optional[WriteConcern] = None,
    ) -> None:
        """Internal helper for executing batches of bulkWrite commands.

        Sends ``bulkWrite`` commands on ``conn`` until every queued operation
        has been sent, a top-level error occurs, or (for ordered bulk writes)
        an individual write fails. Progress is tracked in ``self.idx_offset``
        and results are accumulated into ``full_result``.

        :param write_concern: Write concern applied to each batch.
        :param session: Optional session; when given, the command is also
            passed to ``session._apply_to``.
        :param conn: The connection the commands are sent on.
        :param op_id: Operation id passed to the bulk write context.
        :param retryable: Whether this attempt is a retryable write.
        :param full_result: Accumulated result document, updated in place.
        :param final_write_concern: When truthy, replaces ``write_concern``
            once the remaining operations fit within
            ``bwc.max_write_batch_size``.
        """
        db_name = "admin"
        cmd_name = "bulkWrite"
        listeners = self.client._event_listeners

        # AsyncConnection.command validates the session, but we use
        # AsyncConnection.write_command
        conn.validate_session(self.client, session)

        bwc = self.bulk_ctx_class(
            db_name,
            cmd_name,
            conn,
            op_id,
            listeners,  # type: ignore[arg-type]
            session,
            self.client.codec_options,
        )

        while self.idx_offset < self.total_ops:
            # If this is the last possible batch, use the
            # final write concern.
            if self.total_ops - self.idx_offset <= bwc.max_write_batch_size:
                write_concern = final_write_concern or write_concern

            # Construct the server command, specifying the relevant options.
            cmd = {"bulkWrite": 1}
            cmd["errorsOnly"] = not self.verbose_results
            cmd["ordered"] = self.ordered  # type: ignore[assignment]
            # The write concern is applied unless there is a session that is
            # currently in a transaction.
            not_in_transaction = session and not session.in_transaction
            if not_in_transaction or not session:
                _csot.apply_write_concern(cmd, write_concern)
            if self.bypass_doc_val is not None:
                cmd["bypassDocumentValidation"] = self.bypass_doc_val
            if self.comment:
                cmd["comment"] = self.comment  # type: ignore[assignment]
            if self.let:
                cmd["let"] = self.let

            if session:
                # Start a new retryable write unless one was already
                # started for this command.
                if retryable and not self.started_retryable_write:
                    session._start_retryable_write()
                    self.started_retryable_write = True
                session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
            conn.send_cluster_time(cmd, session, self.client)
            conn.add_server_api(cmd)
            # CSOT: apply timeout before encoding the command.
            conn.apply_timeout(self.client, cmd)
            # Lazily view only the operations that have not been sent yet.
            ops = islice(self.ops, self.idx_offset, None)
            namespaces = islice(self.namespaces, self.idx_offset, None)

            # Run as many ops as possible in one server command.
            if write_concern.acknowledged:
                raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]
                # NOTE: result is an alias of raw_result, not a copy; the
                # assignments to result[...] below mutate the same dict.
                result = raw_result

                # Top-level server/network error.
                if result.get("error"):
                    error = result["error"]
                    retryable_top_level_error = (
                        hasattr(error, "details")
                        and isinstance(error.details, dict)
                        and error.details.get("code", 0) in _RETRYABLE_ERROR_CODES
                    )
                    retryable_network_error = isinstance(
                        error, ConnectionFailure
                    ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError))

                    # Synthesize the full bulk result without modifying the
                    # current one because this write operation may be retried.
                    # NOTE(review): the code after this if/else relies on
                    # _throw_client_bulk_write_exception always raising —
                    # confirm in pymongo._client_bulk_shared.
                    if retryable and (retryable_top_level_error or retryable_network_error):
                        full = copy.deepcopy(full_result)
                        _merge_command(self.ops, self.idx_offset, full, result)
                        _throw_client_bulk_write_exception(full, self.verbose_results)
                    else:
                        _merge_command(self.ops, self.idx_offset, full_result, result)
                        _throw_client_bulk_write_exception(full_result, self.verbose_results)

                # Give the reply the keys that the cursor processing and
                # merging steps below read and write.
                result["error"] = None
                result["writeErrors"] = []
                if result.get("nErrors", 0) < len(to_send_ops):
                    full_result["anySuccessful"] = True

                # Top-level command error.
                if not result["ok"]:
                    # Because of the aliasing above, this stores the reply
                    # dict under its own "error" key.
                    result["error"] = raw_result
                    _merge_command(self.ops, self.idx_offset, full_result, result)
                    break

                if retryable:
                    # Retryable writeConcernErrors halt the execution of this batch.
                    wce = result.get("writeConcernError", {})
                    if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
                        # Synthesize the full bulk result without modifying the
                        # current one because this write operation may be retried.
                        full = copy.deepcopy(full_result)
                        _merge_command(self.ops, self.idx_offset, full, result)
                        _throw_client_bulk_write_exception(full, self.verbose_results)

                # Process the server reply as a command cursor.
                await self._process_results_cursor(full_result, result, conn, session)

                # Merge this batch's results with the full results.
                _merge_command(self.ops, self.idx_offset, full_result, result)

                # We're no longer in a retry once a command succeeds.
                self.retrying = False
                self.started_retryable_write = False

            else:
                to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]

            # Advance past the operations that were sent in this batch.
            self.idx_offset += len(to_send_ops)

            # We halt execution if we hit a top-level error,
            # or an individual error in an ordered bulk write.
            if full_result["error"] or (self.ordered and full_result["writeErrors"]):
                break

    async def execute_command(
        self,
        session: Optional[AsyncClientSession],
        operation: str,
    ) -> MutableMapping[str, Any]:
        """Execute all queued operations using ``self.write_concern``.

        The work is run through ``self.client._retryable_write``, with
        ``self.is_retryable`` deciding whether retries are requested.

        :param session: Optional session to run the commands with.
        :param operation: Operation name forwarded to the client's retry
            helper.
        :return: The accumulated result document.
        :raises InvalidOperation: If the selected connection reports
            ``max_wire_version`` below 25 (the error message names MongoDB
            8.0+ as the requirement).

        After execution, any recorded top-level error, write error or write
        concern error is reported through
        ``_throw_client_bulk_write_exception``.
        """
        full_result: MutableMapping[str, Any] = {
            "anySuccessful": False,
            "error": None,
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nDeleted": 0,
            "insertResults": {},
            "updateResults": {},
            "deleteResults": {},
        }
        op_id = _randint()

        # Callback handed to the client's retry helper; it closes over
        # full_result and op_id so that every attempt shares them.
        async def retryable_bulk(
            session: Optional[AsyncClientSession],
            conn: AsyncConnection,
            retryable: bool,
        ) -> None:
            if conn.max_wire_version < 25:
                raise InvalidOperation(
                    "MongoClient.bulk_write requires MongoDB server version 8.0+."
                )
            await self._execute_command(
                self.write_concern,
                session,
                conn,
                op_id,
                retryable,
                full_result,
            )

        await self.client._retryable_write(
            self.is_retryable,
            retryable_bulk,
            session,
            operation,
            bulk=self,
            operation_id=op_id,
        )

        if full_result["error"] or full_result["writeErrors"] or full_result["writeConcernErrors"]:
            _throw_client_bulk_write_exception(full_result, self.verbose_results)
        return full_result

    async def execute_command_unack_unordered(
        self,
        conn: AsyncConnection,
    ) -> None:
        """Execute commands with OP_MSG and w=0 writeConcern, unordered.

        Every batch is sent with ``writeConcern: {"w": 0}`` and no session;
        replies are not inspected, so sending continues until all queued
        operations have been sent.
        """
        db_name = "admin"
        cmd_name = "bulkWrite"
        listeners = self.client._event_listeners
        op_id = _randint()

        bwc = self.bulk_ctx_class(
            db_name,
            cmd_name,
            conn,
            op_id,
            listeners,  # type: ignore[arg-type]
            None,
            self.client.codec_options,
        )

        while self.idx_offset < self.total_ops:
            # Construct the server command, specifying the relevant options.
            cmd = {"bulkWrite": 1}
            cmd["errorsOnly"] = not self.verbose_results
            cmd["ordered"] = self.ordered  # type: ignore[assignment]
            if self.bypass_doc_val is not None:
                cmd["bypassDocumentValidation"] = self.bypass_doc_val
            cmd["writeConcern"] = {"w": 0}  # type: ignore[assignment]
            if self.comment:
                cmd["comment"] = self.comment  # type: ignore[assignment]
            if self.let:
                cmd["let"] = self.let

            conn.add_server_api(cmd)
            # Lazily view only the operations that have not been sent yet.
            ops = islice(self.ops, self.idx_offset, None)
            namespaces = islice(self.namespaces, self.idx_offset, None)

            # Run as many ops as possible in one server command.
            to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces)  # type: ignore[arg-type]

            self.idx_offset += len(to_send_ops)

    async def execute_command_unack_ordered(
        self,
        conn: AsyncConnection,
    ) -> None:
        """Execute commands with OP_MSG and w=0 WriteConcern, ordered.

        Batches are sent with a default ``WriteConcern()``; the application's
        write concern is passed as the final write concern, which
        ``_execute_command`` switches to once the remaining operations fit in
        a single batch. No result is returned, and an ``OperationFailure``
        raised during execution is deliberately suppressed.
        """
        full_result: MutableMapping[str, Any] = {
            "anySuccessful": False,
            "error": None,
            "writeErrors": [],
            "writeConcernErrors": [],
            "nInserted": 0,
            "nUpserted": 0,
            "nMatched": 0,
            "nModified": 0,
            "nDeleted": 0,
            "insertResults": {},
            "updateResults": {},
            "deleteResults": {},
        }
        # Ordered bulk writes have to be acknowledged so that we stop
        # processing at the first error, even when the application
        # specified unacknowledged writeConcern.
        initial_write_concern = WriteConcern()
        op_id = _randint()
        try:
            await self._execute_command(
                initial_write_concern,
                None,
                conn,
                op_id,
                False,
                full_result,
                self.write_concern,
            )
        except OperationFailure:
            # Best effort: the caller asked for unacknowledged writes, so
            # failures are not reported.
            pass

    async def execute_no_results(
        self,
        conn: AsyncConnection,
    ) -> None:
        """Execute all operations, returning no results (w=0).

        Dispatches to the ordered or unordered unacknowledged implementation
        depending on ``self.ordered``.

        :raises ConfigurationError: If any queued operation uses collation or
            arrayFilters.
        :raises OperationFailure: If ``bypass_document_validation`` was set
            (to any value other than ``None``).
        """
        if self.uses_collation:
            raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
        if self.uses_array_filters:
            raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.")
        # Cannot have both unacknowledged writes and bypass document validation.
        if self.bypass_doc_val is not None:
            raise OperationFailure(
                "Cannot set bypass_document_validation with unacknowledged write concern"
            )

        if self.ordered:
            return await self.execute_command_unack_ordered(conn)
        return await self.execute_command_unack_unordered(conn)

    async def execute(
        self,
        session: Optional[AsyncClientSession],
        operation: str,
    ) -> Any:
        """Execute operations.

        :param session: Optional session; it is passed through
            ``_validate_session_write_concern`` and the returned value is
            used from then on.
        :param operation: Operation name forwarded to the client helpers.
        :return: A ``ClientBulkWriteResult``. For an unacknowledged write
            concern it is built from ``(None, False, False)``; otherwise it
            wraps the accumulated result document.
        :raises InvalidOperation: If no operations were queued, if this
            instance was already executed, or if the connection reports
            ``max_wire_version`` below 25.
        """
        if not self.ops:
            raise InvalidOperation("No operations to execute")
        if self.executed:
            raise InvalidOperation("Bulk operations can only be executed once.")
        # Marked before any I/O, so a failed execution cannot be re-run.
        self.executed = True
        session = _validate_session_write_concern(session, self.write_concern)

        if not self.write_concern.acknowledged:
            async with await self.client._conn_for_writes(session, operation) as connection:
                if connection.max_wire_version < 25:
                    raise InvalidOperation(
                        "MongoClient.bulk_write requires MongoDB server version 8.0+."
                    )
                await self.execute_no_results(connection)
                return ClientBulkWriteResult(None, False, False)  # type: ignore[arg-type]

        result = await self.execute_command(session, operation)
        return ClientBulkWriteResult(
            result,
            self.write_concern.acknowledged,
            self.verbose_results,
        )