# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for creating `messages
<https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/>`_ to be sent to
MongoDB.

.. note:: This module is for internal use and is generally not needed by
   application developers.
"""
from __future__ import annotations

import datetime
import random
import struct
from io import BytesIO as _BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Mapping,
    MutableMapping,
    NoReturn,
    Optional,
    Union,
)

import bson
from bson import CodecOptions, _dict_to_bson, _make_c_string
from bson.int64 import Int64
from bson.raw_bson import (
    _RAW_ARRAY_BSON_OPTIONS,
    DEFAULT_RAW_BSON_OPTIONS,
    RawBSONDocument,
    _inflate_bson,
)
from pymongo.hello import HelloCompat
from pymongo.monitoring import _EventListeners

try:
    from pymongo import _cmessage  # type: ignore[attr-defined]

    _use_c = True
except ImportError:
    _use_c = False
from pymongo.errors import (
    ConfigurationError,
    CursorNotFound,
    DocumentTooLarge,
    ExecutionTimeout,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
    ProtocolError,
)
from pymongo.read_preferences import ReadPreference, _ServerMode

if TYPE_CHECKING:
    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
    from pymongo.read_concern import ReadConcern
    from pymongo.typings import (
        _Address,
        _AgnosticClientSession,
        _AgnosticConnection,
        _AgnosticMongoClient,
        _DocumentOut,
    )


MAX_INT32 = 2147483647
MIN_INT32 = -2147483648

# Overhead allowed for encoded command documents.
_COMMAND_OVERHEAD = 16382

_INSERT = 0
_UPDATE = 1
_DELETE = 2

_EMPTY = b""
_BSONOBJ = b"\x03"
_ZERO_8 = b"\x00"
_ZERO_16 = b"\x00\x00"
_ZERO_32 = b"\x00\x00\x00\x00"
_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00"
_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff"
_OP_MAP = {
    _INSERT: b"\x04documents\x00\x00\x00\x00\x00",
    _UPDATE: b"\x04updates\x00\x00\x00\x00\x00",
    _DELETE: b"\x04deletes\x00\x00\x00\x00\x00",
}
_FIELD_MAP = {
    "insert": "documents",
    "update": "updates",
    "delete": "deletes",
    "bulkWrite": "bulkWrite",
}

_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
    unicode_decode_error_handler="replace"
)


def _randint() -> int:
    """Generate a pseudo random 32-bit integer."""
    return random.randint(MIN_INT32, MAX_INT32)  # noqa: S311


def _maybe_add_read_preference(
    spec: MutableMapping[str, Any], read_preference: _ServerMode
) -> MutableMapping[str, Any]:
    """Add $readPreference to spec when appropriate."""
    mode = read_preference.mode
    document = read_preference.document
    # Only add $readPreference if it's something other than primary to avoid
    # problems with mongos versions that don't support read preferences. Also,
    # for maximum backwards compatibility, don't add $readPreference for
    # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
    # the secondaryOk bit has the same effect).
    if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1):
        if "$query" not in spec:
            spec = {"$query": spec}
        spec["$readPreference"] = document
    return spec
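
# Illustrative sketch (not part of the driver): for a non-primary read
# preference, the filter is wrapped in $query and the mode document is
# attached, e.g.:
#
#   _maybe_add_read_preference({"x": 1}, ReadPreference.SECONDARY)
#   # -> {"$query": {"x": 1}, "$readPreference": {"mode": "secondary"}}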


def _convert_exception(exception: Exception) -> dict[str, Any]:
    """Convert an Exception into a failure document for publishing."""
    return {"errmsg": str(exception), "errtype": exception.__class__.__name__}


def _convert_client_bulk_exception(exception: Exception) -> dict[str, Any]:
    """Convert an Exception into a failure document for publishing,
    for use in the client-level bulk write API.
    """
    return {
        "errmsg": str(exception),
        "code": exception.code,  # type: ignore[attr-defined]
        "errtype": exception.__class__.__name__,
    }


def _convert_write_result(
    operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> dict[str, Any]:
    """Convert a legacy write result to write command format."""
    # Based on _merge_legacy from bulk.py
    affected = result.get("n", 0)
    res = {"ok": 1, "n": affected}
    errmsg = result.get("errmsg", result.get("err", ""))
    if errmsg:
        # The write was successful on at least the primary so don't return.
        if result.get("wtimeout"):
            res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}}
        else:
            # The write failed.
            error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg}
            if "errInfo" in result:
                error["errInfo"] = result["errInfo"]
            res["writeErrors"] = [error]
        return res
    if operation == "insert":
        # GLE result for insert is always 0 in most MongoDB versions.
        res["n"] = len(command["documents"])
    elif operation == "update":
        if "upserted" in result:
            res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
        # Versions of MongoDB before 2.6 don't return the _id for an
        # upsert if _id is not an ObjectId.
        elif result.get("updatedExisting") is False and affected == 1:
            # If _id is in both the update document *and* the query spec
            # the update document _id takes precedence.
            update = command["updates"][0]
            _id = update["u"].get("_id", update["q"].get("_id"))
            res["upserted"] = [{"index": 0, "_id": _id}]
    return res
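
# Illustrative sketch (not part of the driver): a legacy getLastError-style
# reply with a duplicate-key error converts to write-command shape, e.g.:
#
#   cmd = {"insert": "coll", "documents": [{"_id": 1}]}
#   gle = {"n": 0, "err": "E11000 duplicate key error", "code": 11000}
#   _convert_write_result("insert", cmd, gle)
#   # -> {"ok": 1, "n": 0,
#   #     "writeErrors": [{"index": 0, "code": 11000,
#   #                      "errmsg": "E11000 duplicate key error"}]}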


_OPTIONS = {
    "tailable": 2,
    "oplogReplay": 8,
    "noCursorTimeout": 16,
    "awaitData": 32,
    "allowPartialResults": 128,
}


_MODIFIERS = {
    "$query": "filter",
    "$orderby": "sort",
    "$hint": "hint",
    "$comment": "comment",
    "$maxScan": "maxScan",
    "$maxTimeMS": "maxTimeMS",
    "$max": "max",
    "$min": "min",
    "$returnKey": "returnKey",
    "$showRecordId": "showRecordId",
    "$showDiskLoc": "showRecordId",  # <= MongoDB 3.0
    "$snapshot": "snapshot",
}


def _gen_find_command(
    coll: str,
    spec: Mapping[str, Any],
    projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
    skip: int,
    limit: int,
    batch_size: Optional[int],
    options: Optional[int],
    read_concern: ReadConcern,
    collation: Optional[Mapping[str, Any]] = None,
    session: Optional[_AgnosticClientSession] = None,
    allow_disk_use: Optional[bool] = None,
) -> dict[str, Any]:
    """Generate a find command document."""
    cmd: dict[str, Any] = {"find": coll}
    if "$query" in spec:
        cmd.update(
            [
                (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
                for key, val in spec.items()
            ]
        )
        if "$explain" in cmd:
            cmd.pop("$explain")
        if "$readPreference" in cmd:
            cmd.pop("$readPreference")
    else:
        cmd["filter"] = spec

    if projection:
        cmd["projection"] = projection
    if skip:
        cmd["skip"] = skip
    if limit:
        cmd["limit"] = abs(limit)
        if limit < 0:
            cmd["singleBatch"] = True
    if batch_size:
        cmd["batchSize"] = batch_size
    if read_concern.level and not (session and session.in_transaction):
        cmd["readConcern"] = read_concern.document
    if collation:
        cmd["collation"] = collation
    if allow_disk_use is not None:
        cmd["allowDiskUse"] = allow_disk_use
    if options:
        cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])

    return cmd
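
# Illustrative sketch (not part of the driver): a modifier-style spec is
# translated through _MODIFIERS into a find command, e.g.:
#
#   _gen_find_command(
#       "coll", {"$query": {"x": 1}, "$orderby": {"x": 1}},
#       None, 0, -1, None, None, ReadConcern(),
#   )
#   # -> {"find": "coll", "filter": {"x": 1}, "sort": {"x": 1},
#   #     "limit": 1, "singleBatch": True}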


def _gen_get_more_command(
    cursor_id: Optional[int],
    coll: str,
    batch_size: Optional[int],
    max_await_time_ms: Optional[int],
    comment: Optional[Any],
    conn: _AgnosticConnection,
) -> dict[str, Any]:
    """Generate a getMore command document."""
    cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
    if batch_size:
        cmd["batchSize"] = batch_size
    if max_await_time_ms is not None:
        cmd["maxTimeMS"] = max_await_time_ms
    if comment is not None and conn.max_wire_version >= 9:
        cmd["comment"] = comment
    return cmd
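
# Illustrative sketch (not part of the driver): given a connection `conn` at
# wire version >= 9 (MongoDB 4.4+), a getMore might look like:
#
#   _gen_get_more_command(Int64(42), "coll", 100, 500, "my-tag", conn)
#   # -> {"getMore": Int64(42), "collection": "coll", "batchSize": 100,
#   #     "maxTimeMS": 500, "comment": "my-tag"}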


_pack_compression_header = struct.Struct("<iiiiiiB").pack
_COMPRESSION_HEADER_SIZE = 25


def _compress(
    operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
) -> tuple[int, bytes]:
    """Takes message data, compresses it, and adds an OP_COMPRESSED header."""
    compressed = ctx.compress(data)
    request_id = _randint()

    header = _pack_compression_header(
        _COMPRESSION_HEADER_SIZE + len(compressed),  # Total message length
        request_id,  # Request id
        0,  # responseTo
        2012,  # operation id
        operation,  # original operation id
        len(data),  # uncompressed message length
        ctx.compressor_id,  # compressor id
    )
    return request_id, header + compressed
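
# Illustrative sketch (not part of the driver): the returned bytes start
# with the 25-byte OP_COMPRESSED header, which can be unpacked again with
# the same struct layout, e.g.:
#
#   length, rid, resp_to, op_code, orig_op, usize, cid = struct.unpack(
#       "<iiiiiiB", msg[:_COMPRESSION_HEADER_SIZE]
#   )
#   # op_code == 2012 (OP_COMPRESSED); orig_op is the wrapped opcode,
#   # e.g. 2013 for OP_MSG; usize is the uncompressed payload length.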


_pack_header = struct.Struct("<iiii").pack


def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
    """Takes message data and adds a message header based on the operation.

    Returns the request id and the resultant message bytes.
    """
    rid = _randint()
    message = _pack_header(16 + len(data), rid, 0, operation)
    return rid, message + data
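
# Illustrative sketch (not part of the driver): every wire message begins
# with the same 16-byte header of four little-endian int32s:
#
#   message_length, request_id, response_to, op_code = struct.unpack(
#       "<iiii", message[:16]
#   )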


_pack_int = struct.Struct("<i").pack
_pack_op_msg_flags_type = struct.Struct("<IB").pack
_pack_byte = struct.Struct("<B").pack


def _op_msg_no_header(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[bytes, int, int]:
    """Get an OP_MSG message.

    Note: this method handles multiple documents in a type one payload but
    it does not perform batch splitting and the total message size is
    only checked *after* generating the entire message.
    """
    # Encode the command document in payload 0 without checking keys.
    encoded = _dict_to_bson(command, False, opts)
    flags_type = _pack_op_msg_flags_type(flags, 0)
    total_size = len(encoded)
    max_doc_size = 0
    if identifier and docs is not None:
        type_one = _pack_byte(1)
        cstring = _make_c_string(identifier)
        encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
        size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4
        encoded_size = _pack_int(size)
        total_size += size
        max_doc_size = max(len(doc) for doc in encoded_docs)
        data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs]
    else:
        data = [flags_type, encoded]
    return b"".join(data), total_size, max_doc_size
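
# Illustrative sketch (not part of the driver): the body produced above is
#
#   flagBits (uint32) | 0x00 | <command BSON>                 -- section 0
#   [0x01 | size (int32) | identifier\x00 | <doc BSON>...]    -- section 1
#
# so a bulk insert command {"insert": "coll"} with identifier "documents"
# carries each document in the section-1 sequence rather than inside the
# command document itself.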


def _op_msg_compressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int, int]:
    """Internal compressed OP_MSG message helper."""
    msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    rid, msg = _compress(2013, msg, ctx)
    return rid, msg, total_size, max_bson_size


def _op_msg_uncompressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[int, bytes, int, int]:
    """Internal uncompressed OP_MSG message helper."""
    data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    request_id, op_message = __pack_message(2013, data)
    return request_id, op_message, total_size, max_bson_size


if _use_c:
    _op_msg_uncompressed = _cmessage._op_msg


def _op_msg(
    flags: int,
    command: MutableMapping[str, Any],
    dbname: str,
    read_preference: Optional[_ServerMode],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int, int]:
    """Get an OP_MSG message."""
    command["$db"] = dbname
    # getMore commands do not send $readPreference.
    if read_preference is not None and "$readPreference" not in command:
        # Only send $readPreference if it's not primary (the default).
        if read_preference.mode:
            command["$readPreference"] = read_preference.document
    name = next(iter(command))
    try:
        identifier = _FIELD_MAP[name]
        docs = command.pop(identifier)
    except KeyError:
        identifier = ""
        docs = None
    try:
        if ctx:
            return _op_msg_compressed(flags, command, identifier, docs, opts, ctx)
        return _op_msg_uncompressed(flags, command, identifier, docs, opts)
    finally:
        # Add the field back to the command.
        if identifier:
            command[identifier] = docs


def _query_impl(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[bytes, int]:
    """Get an OP_QUERY message."""
    encoded = _dict_to_bson(query, False, opts)
    if field_selector:
        efs = _dict_to_bson(field_selector, False, opts)
    else:
        efs = b""
    max_bson_size = max(len(encoded), len(efs))
    return (
        b"".join(
            [
                _pack_int(options),
                bson._make_c_string(collection_name),
                _pack_int(num_to_skip),
                _pack_int(num_to_return),
                encoded,
                efs,
            ]
        ),
        max_bson_size,
    )


def _query_compressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int]:
    """Internal compressed query message helper."""
    op_query, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    rid, msg = _compress(2004, op_query, ctx)
    return rid, msg, max_bson_size


def _query_uncompressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[int, bytes, int]:
    """Internal query message helper."""
    op_query, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    rid, msg = __pack_message(2004, op_query)
    return rid, msg, max_bson_size


if _use_c:
    _query_uncompressed = _cmessage._query_message


def _query(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int]:
    """Get a **query** message."""
    if ctx:
        return _query_compressed(
            options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx
        )
    return _query_uncompressed(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )


_pack_long_long = struct.Struct("<q").pack


def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
    """Get an OP_GET_MORE message."""
    return b"".join(
        [
            _ZERO_32,
            bson._make_c_string(collection_name),
            _pack_int(num_to_return),
            _pack_long_long(cursor_id),
        ]
    )


def _get_more_compressed(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes]:
    """Internal compressed getMore message helper."""
    return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)


def _get_more_uncompressed(
    collection_name: str, num_to_return: int, cursor_id: int
) -> tuple[int, bytes]:
    """Internal getMore message helper."""
    return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))


if _use_c:
    _get_more_uncompressed = _cmessage._get_more_message


def _get_more(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes]:
    """Get a **getMore** message."""
    if ctx:
        return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
    return _get_more_uncompressed(collection_name, num_to_return, cursor_id)


# OP_MSG -------------------------------------------------------------


_OP_MSG_MAP = {
    _INSERT: b"documents\x00",
    _UPDATE: b"updates\x00",
    _DELETE: b"deletes\x00",
}


class _BulkWriteContextBase:
    """Private base class for wrapping around AsyncConnection/Connection to use with write splitting functions."""

    __slots__ = (
        "db_name",
        "conn",
        "op_id",
        "name",
        "field",
        "publish",
        "start_time",
        "listeners",
        "session",
        "compress",
        "op_type",
        "codec",
    )

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: _AgnosticConnection,
        operation_id: int,
        listeners: _EventListeners,
        session: Optional[_AgnosticClientSession],
        op_type: int,
        codec: CodecOptions,
    ):
        self.db_name = database_name
        self.conn = conn
        self.op_id = operation_id
        self.listeners = listeners
        self.publish = listeners.enabled_for_commands
        self.name = cmd_name
        self.field = _FIELD_MAP[self.name]
        self.start_time = datetime.datetime.now()
        self.session = session
        self.compress = bool(conn.compression_context)
        self.op_type = op_type
        self.codec = codec

    @property
    def max_bson_size(self) -> int:
        """A proxy for Connection.max_bson_size."""
        return self.conn.max_bson_size

    @property
    def max_message_size(self) -> int:
        """A proxy for Connection.max_message_size."""
        if self.compress:
            # Subtract 16 bytes for the message header.
            return self.conn.max_message_size - 16
        return self.conn.max_message_size

    @property
    def max_write_batch_size(self) -> int:
        """A proxy for Connection.max_write_batch_size."""
        return self.conn.max_write_batch_size

    @property
    def max_split_size(self) -> int:
        """The maximum size of a BSON command before batch splitting."""
        return self.max_bson_size

    def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None:
        """Publish a CommandSucceededEvent."""
        self.listeners.publish_command_success(
            duration,
            reply,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )

    def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None:
        """Publish a CommandFailedEvent."""
        self.listeners.publish_command_failure(
            duration,
            failure,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )


class _BulkWriteContext(_BulkWriteContextBase):
    """A wrapper around AsyncConnection/Connection for use with the collection-level bulk write API."""

    __slots__ = ()

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: _AgnosticConnection,
        operation_id: int,
        listeners: _EventListeners,
        session: Optional[_AgnosticClientSession],
        op_type: int,
        codec: CodecOptions,
    ):
        super().__init__(
            database_name,
            cmd_name,
            conn,
            operation_id,
            listeners,
            session,
            op_type,
            codec,
        )

    def batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]:
        namespace = self.db_name + ".$cmd"
        request_id, msg, to_send = _do_batched_op_msg(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")
        return request_id, msg, to_send

    def _start(
        self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]]
    ) -> MutableMapping[str, Any]:
        """Publish a CommandStartedEvent."""
        cmd[self.field] = docs
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
        )
        return cmd


class _EncryptedBulkWriteContext(_BulkWriteContext):
    __slots__ = ()

    def batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]:
        namespace = self.db_name + ".$cmd"
        msg, to_send = _encode_batched_write_command(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")

        # Chop off the OP_QUERY header to get a properly batched write command.
        cmd_start = msg.index(b"\x00", 4) + 9
        outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS)
        return -1, outgoing, to_send

    @property
    def max_split_size(self) -> int:
        """Reduce the batch splitting size."""
        return _MAX_SPLIT_SIZE_ENC


def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
    """Internal helper for raising DocumentTooLarge."""
    if operation == "insert":
        raise DocumentTooLarge(
            "BSON document too large (%d bytes)"
            " - the connected server supports"
            " BSON document sizes up to %d"
            " bytes." % (doc_size, max_size)
        )
    else:
        # There's nothing intelligent we can say
        # about size for update and delete
        raise DocumentTooLarge(f"{operation!r} command document too large")


# From the Client Side Encryption spec:
# Because automatic encryption increases the size of commands, the driver
# MUST split bulk writes at a reduced size limit before undergoing automatic
# encryption. The write payload MUST be split at 2MiB (2097152).
_MAX_SPLIT_SIZE_ENC = 2097152


def _batched_op_msg_impl(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write."""
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    # Flags
    buf.write(flags)

    # Type 0 Section
    buf.write(b"\x00")
    buf.write(_dict_to_bson(command, False, opts))

    # Type 1 Section
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    try:
        buf.write(_OP_MSG_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None

    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        value = _dict_to_bson(doc, False, opts)
        doc_length = len(value)
        new_message_size = buf.tell() + doc_length
        # Does first document exceed max_message_size?
        doc_too_large = idx == 0 and (new_message_size > max_message_size)
        # When OP_MSG is used unacknowledged we have to check
        # document size client side or applications won't be notified.
        # Otherwise we let the server deal with documents that are too large
        # since ordered=False causes those documents to be skipped instead of
        # halting the bulk write operation.
        unacked_doc_too_large = not ack and (doc_length > max_bson_size)
        if doc_too_large or unacked_doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # We have enough data, return this batch.
        if new_message_size > max_message_size:
            break
        buf.write(value)
        to_send.append(doc)
        idx += 1
        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break

    # Write type 1 section size
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))

    return to_send, length


def _encode_batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next batched insert, update, or delete operation
    as OP_MSG.
    """
    buf = _BytesIO()

    to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf)
    return buf.getvalue(), to_send


if _use_c:
    _encode_batched_op_msg = _cmessage._encode_batched_op_msg


def _batched_op_msg_compressed(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Create the next batched insert, update, or delete operation
    with OP_MSG, compressed.
    """
    data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)

    assert ctx.conn.compression_context is not None
    request_id, msg = _compress(2013, data, ctx.conn.compression_context)
    return request_id, msg, to_send


def _batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """OP_MSG implementation entry point."""
    buf = _BytesIO()

    # Save space for message length and request id
    buf.write(_ZERO_64)
    # responseTo, opCode
    buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")

    to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf)

    # Header - request id and message length
    buf.seek(4)
    request_id = _randint()
    buf.write(_pack_int(request_id))
    buf.seek(0)
    buf.write(_pack_int(length))

    return request_id, buf.getvalue(), to_send


if _use_c:
    _batched_op_msg = _cmessage._batched_op_msg


def _do_batched_op_msg(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Create the next batched insert, update, or delete operation
    using OP_MSG.
    """
    command["$db"] = namespace.split(".", 1)[0]
    if "writeConcern" in command:
        ack = bool(command["writeConcern"].get("w", 1))
    else:
        ack = True
    if ctx.conn.compression_context:
        return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx)
    return _batched_op_msg(operation, command, docs, ack, opts, ctx)
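
# Illustrative sketch (not part of the driver): acknowledgement is inferred
# from the write concern, which in turn decides whether the moreToCome flag
# bit is set and whether document sizes are validated client side, e.g.:
#
#   {"insert": "coll", "writeConcern": {"w": 0}}   # ack=False (unacknowledged)
#   {"insert": "coll", "writeConcern": {"w": 1}}   # ack=True
#   {"insert": "coll"}                             # ack=True (default)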


class _ClientBulkWriteContext(_BulkWriteContextBase):
    """A wrapper around AsyncConnection/Connection for use with the client-level bulk write API."""

    __slots__ = ()

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: _AgnosticConnection,
        operation_id: int,
        listeners: _EventListeners,
        session: Optional[_AgnosticClientSession],
        codec: CodecOptions,
    ):
        super().__init__(
            database_name,
            cmd_name,
            conn,
            operation_id,
            listeners,
            session,
            0,
            codec,
        )

    def batch_command(
        self,
        cmd: MutableMapping[str, Any],
        operations: list[tuple[str, Mapping[str, Any]]],
        namespaces: list[str],
    ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
        request_id, msg, to_send_ops, to_send_ns = _client_do_batched_op_msg(
            cmd, operations, namespaces, self.codec, self
        )
        if not to_send_ops:
            raise InvalidOperation("cannot do an empty bulk write")
        return request_id, msg, to_send_ops, to_send_ns

    def _start(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        op_docs: list[Mapping[str, Any]],
        ns_docs: list[Mapping[str, Any]],
    ) -> MutableMapping[str, Any]:
        """Publish a CommandStartedEvent."""
        cmd["ops"] = op_docs
        cmd["nsInfo"] = ns_docs
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
        )
        return cmd


_OP_MSG_OVERHEAD = 1000


def _client_construct_op_msg(
    command_encoded: bytes,
    to_send_ops_encoded: list[bytes],
    to_send_ns_encoded: list[bytes],
    ack: bool,
    buf: _BytesIO,
) -> int:
    # Write flags
    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    buf.write(flags)

    # Type 0 Section
    buf.write(b"\x00")
    buf.write(command_encoded)

    # Type 1 Section for ops
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    buf.write(b"ops\x00")
    # Write all the ops documents
    for op_encoded in to_send_ops_encoded:
        buf.write(op_encoded)
    resume_location = buf.tell()
    # Write type 1 section size
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))
    buf.seek(resume_location)

    # Type 1 Section for nsInfo
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    buf.write(b"nsInfo\x00")
    # Write all the nsInfo documents
    for ns_encoded in to_send_ns_encoded:
        buf.write(ns_encoded)
    # Write type 1 section size
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))

    return length


def _client_batched_op_msg_impl(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write for client-level bulk write."""

    def _check_doc_size_limits(
        op_type: str,
        doc_size: int,
        limit: int,
    ) -> None:
        if doc_size > limit:
            _raise_document_too_large(op_type, doc_size, limit)

    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

    command_encoded = _dict_to_bson(command, False, opts)
    # When OP_MSG is used unacknowledged we have to check command
    # document size client-side or applications won't be notified.
    if not ack:
        _check_doc_size_limits("bulkWrite", len(command_encoded), max_bson_size + _COMMAND_OVERHEAD)

    # Don't include bulkWrite-command-agnostic fields in batch-splitting calculations.
    abridged_keys = ["bulkWrite", "errorsOnly", "ordered"]
    if command.get("bypassDocumentValidation"):
        abridged_keys.append("bypassDocumentValidation")
    if command.get("comment"):
        abridged_keys.append("comment")
    if command.get("let"):
        abridged_keys.append("let")
    command_abridged = {key: command[key] for key in abridged_keys}
    command_len_abridged = len(_dict_to_bson(command_abridged, False, opts))

    # Maximum combined size of the ops and nsInfo document sequences.
    max_doc_sequences_bytes = max_message_size - (_OP_MSG_OVERHEAD + command_len_abridged)

    ns_info = {}
    to_send_ops: list[Mapping[str, Any]] = []
    to_send_ns: list[Mapping[str, str]] = []
    to_send_ops_encoded: list[bytes] = []
    to_send_ns_encoded: list[bytes] = []
    total_ops_length = 0
    total_ns_length = 0
    idx = 0

    for (real_op_type, op_doc), namespace in zip(operations, namespaces):
        op_type = real_op_type
        # Check insert/replace document size if unacknowledged.
        if real_op_type == "insert":
            if not ack:
                doc_size = len(_dict_to_bson(op_doc["document"], False, opts))
                _check_doc_size_limits(real_op_type, doc_size, max_bson_size)
        if real_op_type == "replace":
            op_type = "update"
            if not ack:
                doc_size = len(_dict_to_bson(op_doc["updateMods"], False, opts))
                _check_doc_size_limits(real_op_type, doc_size, max_bson_size)

        ns_doc = None
        ns_length = 0

        if namespace not in ns_info:
            ns_doc = {"ns": namespace}
            new_ns_index = len(to_send_ns)
            ns_info[namespace] = new_ns_index

        # First entry in the operation doc has the operation type as its
        # key and the index of its namespace within ns_info as its value.
        op_doc[op_type] = ns_info[namespace]  # type: ignore[index]

        # Encode current operation doc and, if newly added, namespace doc.
        op_doc_encoded = _dict_to_bson(op_doc, False, opts)
        op_length = len(op_doc_encoded)
        if ns_doc:
            ns_doc_encoded = _dict_to_bson(ns_doc, False, opts)
            ns_length = len(ns_doc_encoded)

        # Check operation document size if unacknowledged.
        if not ack:
            _check_doc_size_limits(op_type, op_length, max_bson_size + _COMMAND_OVERHEAD)

        new_message_size = total_ops_length + total_ns_length + op_length + ns_length
        # We have enough data, return this batch.
        if new_message_size > max_doc_sequences_bytes:
            if idx == 0:
                _raise_document_too_large(op_type, op_length, max_bson_size + _COMMAND_OVERHEAD)
            break

        # Add op and ns documents to this batch.
        to_send_ops.append(op_doc)
        to_send_ops_encoded.append(op_doc_encoded)
        total_ops_length += op_length
        if ns_doc:
            to_send_ns.append(ns_doc)
            to_send_ns_encoded.append(ns_doc_encoded)
            total_ns_length += ns_length

        idx += 1

        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break

    # Construct the entire OP_MSG.
    length = _client_construct_op_msg(
        command_encoded, to_send_ops_encoded, to_send_ns_encoded, ack, buf
    )

    return to_send_ops, to_send_ns, length


def _client_encode_batched_op_msg(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Encode the next batched client-level bulkWrite
    operation as OP_MSG.
    """
    buf = _BytesIO()

    to_send_ops, to_send_ns, _ = _client_batched_op_msg_impl(
        command, operations, namespaces, ack, opts, ctx, buf
    )
    return buf.getvalue(), to_send_ops, to_send_ns


def _client_batched_op_msg_compressed(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Create the next batched client-level bulkWrite operation
    with OP_MSG, compressed.
    """
    data, to_send_ops, to_send_ns = _client_encode_batched_op_msg(
        command, operations, namespaces, ack, opts, ctx
    )

    assert ctx.conn.compression_context is not None
    request_id, msg = _compress(2013, data, ctx.conn.compression_context)
    return request_id, msg, to_send_ops, to_send_ns


def _client_batched_op_msg(
    command: Mapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    ack: bool,
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """OP_MSG implementation entry point for client-level bulkWrite."""
    buf = _BytesIO()

    # Save space for message length and request id
    buf.write(_ZERO_64)
    # responseTo, opCode
    buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")

    to_send_ops, to_send_ns, length = _client_batched_op_msg_impl(
        command, operations, namespaces, ack, opts, ctx, buf
    )

    # Header - request id and message length
    buf.seek(4)
    request_id = _randint()
    buf.write(_pack_int(request_id))
    buf.seek(0)
    buf.write(_pack_int(length))

    return request_id, buf.getvalue(), to_send_ops, to_send_ns


def _client_do_batched_op_msg(
    command: MutableMapping[str, Any],
    operations: list[tuple[str, Mapping[str, Any]]],
    namespaces: list[str],
    opts: CodecOptions,
    ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
    """Create the next batched client-level bulkWrite
    operation using OP_MSG.
    """
    command["$db"] = "admin"
    if "writeConcern" in command:
        ack = bool(command["writeConcern"].get("w", 1))
    else:
        ack = True
    if ctx.conn.compression_context:
        return _client_batched_op_msg_compressed(command, operations, namespaces, ack, opts, ctx)
    return _client_batched_op_msg(command, operations, namespaces, ack, opts, ctx)


# End OP_MSG -----------------------------------------------------


def _encode_batched_write_command(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next batched insert, update, or delete command."""
    buf = _BytesIO()

    to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf)
    return buf.getvalue(), to_send


if _use_c:
    _encode_batched_write_command = _cmessage._encode_batched_write_command


def _batched_write_command_impl(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_QUERY write command."""
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    # Max BSON object size + 16k - 2 bytes for ending NUL bytes.
    # Server guarantees there is enough room: SERVER-10643.
    max_cmd_size = max_bson_size + _COMMAND_OVERHEAD
    max_split_size = ctx.max_split_size

    # No options
    buf.write(_ZERO_32)
    # Namespace as C string
    buf.write(namespace.encode("utf8"))
    buf.write(_ZERO_8)
    # Skip: 0, Limit: -1
    buf.write(_SKIPLIM)

    # Where to write command document length
    command_start = buf.tell()
    buf.write(bson.encode(command))

    # Start of payload
    buf.seek(-1, 2)
    # Work around some Jython weirdness.
    buf.truncate()
    try:
        buf.write(_OP_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None

    # Where to write list document length
    list_start = buf.tell() - 4
    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        key = str(idx).encode("utf8")
        value = _dict_to_bson(doc, False, opts)
        # Is there enough room to add this document? max_cmd_size accounts for
        # the two trailing null bytes.
        doc_too_large = len(value) > max_cmd_size
        if doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size
        enough_documents = idx >= max_write_batch_size
        if enough_data or enough_documents:
            break
        buf.write(_BSONOBJ)
        buf.write(key)
        buf.write(_ZERO_8)
        buf.write(value)
        to_send.append(doc)
        idx += 1

    # Finalize the current OP_QUERY message.
    # Close list and command documents
    buf.write(_ZERO_16)

    # Write document lengths and request id
    length = buf.tell()
    buf.seek(list_start)
    buf.write(_pack_int(length - list_start - 1))
    buf.seek(command_start)
    buf.write(_pack_int(length - command_start))

    return to_send, length


class _OpReply:
    """A MongoDB OP_REPLY response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "documents")

    UNPACK_FROM = struct.Struct("<iqii").unpack_from
    OP_CODE = 1

    def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
        self.flags = flags
        self.cursor_id = Int64(cursor_id)
        self.number_returned = number_returned
        self.documents = documents

    def raw_response(
        self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None
    ) -> list[bytes]:
        """Check the response header from the database, without decoding BSON.

        Check the response for errors and unpack.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :param cursor_id: cursor_id we sent to get this response -
            used for raising an informative exception when we get cursor id not
            valid at server response.
        """
        if self.flags & 1:
            # Shouldn't get this response if we aren't doing a getMore
            if cursor_id is None:
                raise ProtocolError("No cursor id for getMore operation")

            # Fake a getMore command response. OP_GET_MORE provides no
            # document.
            msg = "Cursor not found, cursor id: %d" % (cursor_id,)
            errobj = {"ok": 0, "errmsg": msg, "code": 43}
            raise CursorNotFound(msg, 43, errobj)
        elif self.flags & 2:
            error_object: dict = bson.BSON(self.documents).decode()
            # Fake the ok field if it doesn't exist.
            error_object.setdefault("ok", 0)
            if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
                raise NotPrimaryError(error_object["$err"], error_object)
            elif error_object.get("code") == 50:
                default_msg = "operation exceeded time limit"
                raise ExecutionTimeout(
                    error_object.get("$err", default_msg), error_object.get("code"), error_object
                )
            raise OperationFailure(
                "database error: %s" % error_object.get("$err"),
                error_object.get("code"),
                error_object,
            )
        if self.documents:
            return [self.documents]
        return []

    def unpack_response(
        self,
        cursor_id: Optional[int] = None,
        codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[dict[str, Any]]:
        """Unpack a response from the database and decode the BSON document(s).

        Check the response for errors and unpack, returning a dictionary
        containing the response data.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :param cursor_id: cursor_id we sent to get this response -
            used for raising an informative exception when we get cursor id not
            valid at server response
        :param codec_options: an instance of
            :class:`~bson.codec_options.CodecOptions`
        :param user_fields: Response fields that should be decoded
            using the TypeDecoders from codec_options, passed to
            bson._decode_all_selective.
        """
        self.raw_response(cursor_id)
        if legacy_response:
            return bson.decode_all(self.documents, codec_options)
        return bson._decode_all_selective(self.documents, codec_options, user_fields)

    def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
        """Unpack a command response."""
        docs = self.unpack_response(codec_options=codec_options)
        assert self.number_returned == 1
        return docs[0]

    def raw_command_response(self) -> NoReturn:
        """Return the bytes of the command response."""
        # This should never be called on _OpReply.
        raise NotImplementedError

    @property
    def more_to_come(self) -> bool:
        """Is the moreToCome bit set on this response?"""
        return False

    @classmethod
    def unpack(cls, msg: bytes) -> _OpReply:
        """Construct an _OpReply from raw bytes."""
        # PYTHON-945: ignore starting_from field.
        flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)

        documents = msg[20:]
        return cls(flags, cursor_id, number_returned, documents)


class _OpMsg:
    """A MongoDB OP_MSG response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "payload_document")

    UNPACK_FROM = struct.Struct("<IBi").unpack_from
    OP_CODE = 2013

    # Flag bits.
    CHECKSUM_PRESENT = 1
    MORE_TO_COME = 1 << 1
    EXHAUST_ALLOWED = 1 << 16  # Only present on requests.

    def __init__(self, flags: int, payload_document: bytes):
        self.flags = flags
        self.payload_document = payload_document

    def raw_response(
        self,
        cursor_id: Optional[int] = None,
        user_fields: Optional[Mapping[str, Any]] = {},
    ) -> list[Mapping[str, Any]]:
        """
        cursor_id is ignored
        user_fields is used to determine which fields must not be decoded
        """
        inflated_response = bson._decode_selective(
            RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS
        )
        return [inflated_response]

    def unpack_response(
        self,
        cursor_id: Optional[int] = None,
        codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
        user_fields: Optional[Mapping[str, Any]] = None,
        legacy_response: bool = False,
    ) -> list[dict[str, Any]]:
        """Unpack an OP_MSG command response.

        :param cursor_id: Ignored, for compatibility with _OpReply.
        :param codec_options: an instance of
            :class:`~bson.codec_options.CodecOptions`
        :param user_fields: Response fields that should be decoded
            using the TypeDecoders from codec_options, passed to
            bson._decode_all_selective.
        """
        # If _OpMsg is in-use, this cannot be a legacy response.
        assert not legacy_response
        return bson._decode_all_selective(self.payload_document, codec_options, user_fields)

    def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
        """Unpack a command response."""
        return self.unpack_response(codec_options=codec_options)[0]

    def raw_command_response(self) -> bytes:
        """Return the bytes of the command response."""
        return self.payload_document

    @property
    def more_to_come(self) -> bool:
        """Is the moreToCome bit set on this response?"""
        return bool(self.flags & self.MORE_TO_COME)

    @classmethod
    def unpack(cls, msg: bytes) -> _OpMsg:
        """Construct an _OpMsg from raw bytes."""
        flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
        if flags != 0:
            if flags & cls.CHECKSUM_PRESENT:
                raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}")

            if flags ^ cls.MORE_TO_COME:
                raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}")
        if first_payload_type != 0:
            raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}")

        if len(msg) != first_payload_size + 5:
            raise ProtocolError("Unsupported OP_MSG reply: >1 section")

        payload_document = msg[5:]
        return cls(flags, payload_document)


_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
    _OpReply.OP_CODE: _OpReply.unpack,
    _OpMsg.OP_CODE: _OpMsg.unpack,
}
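
# Illustrative sketch (not part of the driver): a reply is dispatched on the
# opCode taken from its 16-byte header, e.g.:
#
#   _, _, _, op_code = struct.unpack("<iiii", raw[:16])
#   reply = _UNPACK_REPLY[op_code](raw[16:])  # an _OpReply or _OpMsg
#   docs = reply.unpack_response()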


class _Query:
    """A query operation."""

    __slots__ = (
        "flags",
        "db",
        "coll",
        "ntoskip",
        "spec",
        "fields",
        "codec_options",
        "read_preference",
        "limit",
        "batch_size",
        "name",
        "read_concern",
        "collation",
        "session",
        "client",
        "allow_disk_use",
        "_as_command",
        "exhaust",
    )

    # For compatibility with the _GetMore class.
    conn_mgr = None
    cursor_id = None

    def __init__(
        self,
        flags: int,
        db: str,
        coll: str,
        ntoskip: int,
        spec: Mapping[str, Any],
        fields: Optional[Mapping[str, Any]],
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        limit: int,
        batch_size: int,
        read_concern: ReadConcern,
        collation: Optional[Mapping[str, Any]],
        session: Optional[_AgnosticClientSession],
        client: _AgnosticMongoClient,
        allow_disk_use: Optional[bool],
        exhaust: bool,
    ):
        self.flags = flags
        self.db = db
        self.coll = coll
        self.ntoskip = ntoskip
        self.spec = spec
        self.fields = fields
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.read_concern = read_concern
        self.limit = limit
        self.batch_size = batch_size
        self.collation = collation
        self.session = session
        self.client = client
        self.allow_disk_use = allow_disk_use
        self.name = "find"
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust

    def reset(self) -> None:
        self._as_command = None

    def namespace(self) -> str:
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: _AgnosticConnection) -> bool:
        use_find_cmd = False
        if not self.exhaust:
            use_find_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_find_cmd = True
        elif not self.read_concern.ok_for_legacy:
            raise ConfigurationError(
                "read concern level of %s is not valid "
                "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
            )

        conn.validate_session(self.client, self.session)  # type: ignore[arg-type]
        return use_find_cmd

    def update_command(self, cmd: dict[str, Any]) -> None:
        self._as_command = cmd, self.db

    def as_command(
        self, conn: _AgnosticConnection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a find command document for this query."""
        # We use the command twice: on the wire and for command monitoring.
        # Generate it once, for speed and to avoid repeating side-effects.
        if self._as_command is not None:
            return self._as_command

        explain = "$explain" in self.spec
        cmd: dict[str, Any] = _gen_find_command(
            self.coll,
            self.spec,
            self.fields,
            self.ntoskip,
            self.limit,
            self.batch_size,
            self.flags,
            self.read_concern,
            self.collation,
            self.session,
            self.allow_disk_use,
        )
        if explain:
            self.name = "explain"
            cmd = {"explain": cmd}
        conn.add_server_api(cmd)
        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, conn)  # type: ignore[arg-type]
            # Explain does not support readConcern.
            if not explain and not self.session.in_transaction:
                self.session._update_read_concern(cmd, conn)  # type: ignore[arg-type]
        conn.send_cluster_time(cmd, self.session, self.client)  # type: ignore[arg-type]
        # Support CSOT
        if apply_timeout:
            conn.apply_timeout(self.client, cmd=cmd)  # type: ignore[arg-type]
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False
    ) -> tuple[int, bytes, int]:
        """Get a query message, possibly setting the secondaryOk bit."""
        # Use the read_preference decided by _socket_from_server.
        self.read_preference = read_preference
        if read_preference.mode:
            # Set the secondaryOk bit.
            flags = self.flags | 4
        else:
            flags = self.flags

        ns = self.namespace()
        spec = self.spec

        if use_cmd:
            spec = self.as_command(conn)[0]
            request_id, msg, size, _ = _op_msg(
                0,
                spec,
                self.db,
                read_preference,
                self.codec_options,
                ctx=conn.compression_context,
            )
            return request_id, msg, size

        # OP_QUERY treats ntoreturn of -1 and 1 the same, return
        # one document and close the cursor. We have to use 2 for
        # batch size if 1 is specified.
        ntoreturn = 2 if self.batch_size == 1 else self.batch_size
        if self.limit:
            if ntoreturn:
                ntoreturn = min(self.limit, ntoreturn)
            else:
                ntoreturn = self.limit

        if conn.is_mongos:
            assert isinstance(spec, MutableMapping)
            spec = _maybe_add_read_preference(spec, read_preference)

        return _query(
            flags,
            ns,
            self.ntoskip,
            ntoreturn,
            spec,
            None if use_cmd else self.fields,
            self.codec_options,
            ctx=conn.compression_context,
        )


class _GetMore:
    """A getMore operation."""

    __slots__ = (
        "db",
        "coll",
        "ntoreturn",
        "cursor_id",
        "max_await_time_ms",
        "codec_options",
        "read_preference",
        "session",
        "client",
        "conn_mgr",
        "_as_command",
        "exhaust",
        "comment",
    )

    name = "getMore"

    def __init__(
        self,
        db: str,
        coll: str,
        ntoreturn: int,
        cursor_id: int,
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        session: Optional[_AgnosticClientSession],
        client: _AgnosticMongoClient,
        max_await_time_ms: Optional[int],
        conn_mgr: Any,
        exhaust: bool,
        comment: Any,
    ):
        self.db = db
        self.coll = coll
        self.ntoreturn = ntoreturn
        self.cursor_id = cursor_id
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.session = session
        self.client = client
        self.max_await_time_ms = max_await_time_ms
        self.conn_mgr = conn_mgr
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust
        self.comment = comment

    def reset(self) -> None:
        self._as_command = None

    def namespace(self) -> str:
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: _AgnosticConnection) -> bool:
        use_cmd = False
        if not self.exhaust:
            use_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_cmd = True

        conn.validate_session(self.client, self.session)  # type: ignore[arg-type]
        return use_cmd

    def update_command(self, cmd: dict[str, Any]) -> None:
        self._as_command = cmd, self.db

    def as_command(
        self, conn: _AgnosticConnection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a getMore command document for this query."""
        # See _Query.as_command for an explanation of this caching.
        if self._as_command is not None:
            return self._as_command

        cmd: dict[str, Any] = _gen_get_more_command(
            self.cursor_id,
            self.coll,
            self.ntoreturn,
            self.max_await_time_ms,
            self.comment,
            conn,
        )
        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, conn)  # type: ignore[arg-type]
        conn.add_server_api(cmd)
        conn.send_cluster_time(cmd, self.session, self.client)  # type: ignore[arg-type]
        # Support CSOT
        if apply_timeout:
            conn.apply_timeout(self.client, cmd=None)  # type: ignore[arg-type]
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False
    ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
        """Get a getMore message."""
        ns = self.namespace()
        ctx = conn.compression_context

        if use_cmd:
            spec = self.as_command(conn)[0]
            if self.conn_mgr and self.exhaust:
                flags = _OpMsg.EXHAUST_ALLOWED
            else:
                flags = 0
            request_id, msg, size, _ = _op_msg(
                flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
            )
            return request_id, msg, size

        return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)


class _RawBatchQuery(_Query):
    def use_command(self, conn: _AgnosticConnection) -> bool:
        # Compatibility checks.
        super().use_command(conn)
        if conn.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif not self.exhaust:
            return True
        return False


class _RawBatchGetMore(_GetMore):
    def use_command(self, conn: _AgnosticConnection) -> bool:
        # Compatibility checks.
        super().use_command(conn)
        if conn.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif not self.exhaust:
            return True
        return False


class _CursorAddress(tuple):
    """The server address (host, port) of a cursor, with namespace property."""

    __namespace: Any

    def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
        self = tuple.__new__(cls, address)
        self.__namespace = namespace
        return self

    @property
    def namespace(self) -> str:
        """The namespace of this cursor."""
        return self.__namespace

    def __hash__(self) -> int:
        # Two _CursorAddress instances with different namespaces
        # must not hash the same.
        return ((*self, self.__namespace)).__hash__()

    def __eq__(self, other: object) -> bool:
        if isinstance(other, _CursorAddress):
            return tuple(self) == tuple(other) and self.namespace == other.namespace
        return NotImplemented

    def __ne__(self, other: object) -> bool:
        return not self == other
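
# Illustrative sketch (not part of the driver): _CursorAddress behaves like
# the (host, port) tuple it subclasses, but equality and hashing between two
# _CursorAddress instances also take the namespace into account, e.g.:
#
#   a = _CursorAddress(("localhost", 27017), "db.coll")
#   b = _CursorAddress(("localhost", 27017), "db.other")
#   a == b         # False: same address, different namespace
#   a.namespace    # "db.coll"
#   # hash() mixes in the namespace, so a and b hash like distinct tuples.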