queues/venv/lib/python3.11/site-packages/pymongo/synchronous/aggregation.py
Egor Matveev 6c6a549aff
All checks were successful
Deploy Prod / Build (pull_request) Successful in 9s
Deploy Prod / Push (pull_request) Successful in 12s
Deploy Prod / Deploy prod (pull_request) Successful in 10s
fix
2024-12-28 22:48:16 +03:00

258 lines
9.2 KiB
Python

# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Perform aggregation operations on a collection or database."""
from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo import common
from pymongo.collation import validate_collation_or_none
from pymongo.errors import ConfigurationError
from pymongo.read_preferences import ReadPreference, _AggWritePref
if TYPE_CHECKING:
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.server import Server
from pymongo.typings import _DocumentType, _Pipeline
# Module-level flag, always True in this file. It is not read anywhere in the
# visible portion of this module.
# NOTE(review): presumably marks this as the synchronous variant of the
# aggregation helpers (the file lives under pymongo/synchronous/) -- confirm
# against whatever code consumes it.
_IS_SYNC = True
class _AggregationCommand:
    """The internal abstract base class for aggregation cursors.

    Should not be called directly by application developers. Use
    :meth:`pymongo.collection.Collection.aggregate`, or
    :meth:`pymongo.database.Database.aggregate` instead.

    Subclasses must provide ``_aggregation_target``, ``_cursor_namespace``,
    ``_cursor_collection`` and ``_database``; the versions defined here only
    raise :class:`NotImplementedError`.
    """

    def __init__(
        self,
        target: Union[Database, Collection],
        cursor_class: type[CommandCursor],
        pipeline: _Pipeline,
        options: MutableMapping[str, Any],
        explicit_session: bool,
        let: Optional[Mapping[str, Any]] = None,
        user_fields: Optional[MutableMapping[str, Any]] = None,
        result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
        comment: Any = None,
    ) -> None:
        """Validate the arguments and precompute the command options.

        :param target: The database or collection the aggregation runs
            against. It supplies the read preference, read concern, write
            concern and codec options used later.
        :param cursor_class: Class instantiated by :meth:`get_cursor` to wrap
            the server reply.
        :param pipeline: The aggregation pipeline. If its last stage contains
            ``$out`` or ``$merge`` the command is treated as a write.
        :param options: Extra ``aggregate`` command options. NOTE: this
            mapping is stored and mutated in place -- ``let``/``comment`` and
            a default ``cursor`` document are added, while ``batchSize``,
            ``collation`` and ``maxAwaitTimeMS`` are popped out of it.
        :param explicit_session: Forwarded unchanged to ``cursor_class``.
        :param let: Optional mapping stored under ``options["let"]``.
        :param user_fields: Forwarded to ``conn.command`` as ``user_fields``.
        :param result_processor: Optional callback invoked with
            ``(result, conn)`` right after the command returns.
        :param comment: Optional value stored under ``options["comment"]``
            and also handed to the cursor.
        :raises ConfigurationError: If ``options`` contains ``"explain"``.
        """
        if "explain" in options:
            raise ConfigurationError(
                "The explain option is not supported. Use Database.command instead."
            )

        self._target = target

        pipeline = common.validate_list("pipeline", pipeline)
        self._pipeline = pipeline
        # Only the final stage is inspected: a pipeline "performs a write"
        # when that stage contains $out or $merge.
        self._performs_write = bool(pipeline) and (
            "$out" in pipeline[-1] or "$merge" in pipeline[-1]
        )

        common.validate_is_mapping("options", options)
        if let is not None:
            common.validate_is_mapping("let", let)
            options["let"] = let
        if comment is not None:
            options["comment"] = comment

        self._options = options

        # This is the batchSize that will be used for setting the initial
        # batchSize for the cursor, as well as the subsequent getMores.
        self._batch_size = common.validate_non_negative_integer_or_none(
            "batchSize", self._options.pop("batchSize", None)
        )

        # If the cursor option is already specified, avoid overriding it.
        self._options.setdefault("cursor", {})
        # If the pipeline performs a write, we ignore the initial batchSize
        # since the server doesn't return results in this case.
        if self._batch_size is not None and not self._performs_write:
            self._options["cursor"]["batchSize"] = self._batch_size

        self._cursor_class = cursor_class
        self._explicit_session = explicit_session
        self._user_fields = user_fields
        self._result_processor = result_processor

        # collation and maxAwaitTimeMS are removed from the options so they
        # are not sent as part of ``cmd``; collation is passed to
        # ``conn.command`` separately and maxAwaitTimeMS goes to the cursor.
        self._collation = validate_collation_or_none(options.pop("collation", None))

        self._max_await_time_ms = options.pop("maxAwaitTimeMS", None)
        # Lazily filled by get_read_preference() for writing pipelines.
        self._write_preference: Optional[_AggWritePref] = None

    @property
    def _aggregation_target(self) -> Union[str, int]:
        """The argument to pass to the aggregate command."""
        raise NotImplementedError

    @property
    def _cursor_namespace(self) -> str:
        """The namespace in which the aggregate command is run."""
        raise NotImplementedError

    def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection:
        """The Collection used for the aggregate command cursor."""
        raise NotImplementedError

    @property
    def _database(self) -> Database:
        """The database against which the aggregation command is run."""
        raise NotImplementedError

    def get_read_preference(
        self, session: Optional[ClientSession]
    ) -> Union[_AggWritePref, _ServerMode]:
        """Return the read preference to use for this aggregation.

        A previously cached ``_AggWritePref`` is returned as-is. Otherwise the
        target's ``_read_preference_for(session)`` is used; when the pipeline
        performs a write and that preference is not ``PRIMARY`` it is wrapped
        in ``_AggWritePref`` and cached for subsequent calls.
        """
        if self._write_preference:
            return self._write_preference
        pref = self._target._read_preference_for(session)
        if self._performs_write and pref != ReadPreference.PRIMARY:
            self._write_preference = pref = _AggWritePref(pref)  # type: ignore[assignment]
        return pref

    def get_cursor(
        self,
        session: Optional[ClientSession],
        server: Server,
        conn: Connection,
        read_preference: _ServerMode,
    ) -> CommandCursor[_DocumentType]:
        """Run the aggregate command on ``conn`` and wrap the reply in a cursor.

        :param session: Session passed to the command and to the cursor.
        :param server: Not used by this method itself.
        :param conn: Connection on which the command is executed.
        :param read_preference: Read preference sent with the command.
        :return: An instance of the ``cursor_class`` given at construction.
        """
        # Serialize command.
        cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline}
        cmd.update(self._options)

        # Apply this target's read concern if:
        # readConcern has not been specified as a kwarg and either
        # - server version is >= 4.2 or
        # - server version is >= 3.2 and pipeline doesn't use $out
        if ("readConcern" not in cmd) and (
            not self._performs_write or (conn.max_wire_version >= 8)
        ):
            read_concern = self._target.read_concern
        else:
            read_concern = None

        # Apply this target's write concern if:
        # writeConcern has not been specified as a kwarg and the pipeline
        # performs a write operation ($out/$merge as its last stage).
        if "writeConcern" not in cmd and self._performs_write:
            write_concern = self._target._write_concern_for(session)
        else:
            write_concern = None

        # Run command.
        result = conn.command(
            self._database.name,
            cmd,
            read_preference,
            self._target.codec_options,
            parse_write_concern_error=True,
            read_concern=read_concern,
            write_concern=write_concern,
            collation=self._collation,
            session=session,
            client=self._database.client,
            user_fields=self._user_fields,
        )

        if self._result_processor:
            self._result_processor(result, conn)

        # Extract cursor from result or mock/fake one if necessary.
        if "cursor" in result:
            cursor = result["cursor"]
        else:
            # Unacknowledged $out/$merge write. Fake a cursor.
            cursor = {
                "id": 0,
                "firstBatch": result.get("result", []),
                "ns": self._cursor_namespace,
            }

        # Create and return cursor instance.
        cmd_cursor = self._cursor_class(
            self._cursor_collection(cursor),
            cursor,
            conn.address,
            # A batch size of None is passed to the cursor as 0.
            batch_size=self._batch_size or 0,
            max_await_time_ms=self._max_await_time_ms,
            session=session,
            explicit_session=self._explicit_session,
            comment=self._options.get("comment"),
        )
        cmd_cursor._maybe_pin_connection(conn)
        return cmd_cursor
class _CollectionAggregationCommand(_AggregationCommand):
    """Aggregation command whose target is a collection."""

    _target: Collection

    @property
    def _database(self) -> Database:
        """The ``database`` attribute of the target collection."""
        collection = self._target
        return collection.database

    @property
    def _aggregation_target(self) -> str:
        """The target collection's ``name``, used as the ``aggregate`` value."""
        collection = self._target
        return collection.name

    @property
    def _cursor_namespace(self) -> str:
        """The target collection's ``full_name``."""
        collection = self._target
        return collection.full_name

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
        """Return the target collection; the ``cursor`` document is ignored."""
        return self._target
class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
    """Collection aggregation variant used for raw batches."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise like the parent, then force the initial batchSize to 0.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Writing pipelines keep whatever cursor options the parent set up.
        if self._performs_write:
            return

        # For raw-batches, the initial batchSize for the cursor is set to 0.
        cursor_options = self._options["cursor"]
        cursor_options["batchSize"] = 0
class _DatabaseAggregationCommand(_AggregationCommand):
    """Aggregation command whose target is a database."""

    _target: Database

    @property
    def _database(self) -> Database:
        """The target itself, since the target is already a database."""
        return self._target

    @property
    def _aggregation_target(self) -> int:
        """Always ``1``, the ``aggregate`` value used for this command."""
        return 1

    @property
    def _cursor_namespace(self) -> str:
        """The ``<db>.$cmd.aggregate`` namespace built from the target's name."""
        return f"{self._target.name}.$cmd.aggregate"

    def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
        """The Collection used for the aggregate command cursor."""
        # Collection level aggregate may not always return the "ns" field
        # according to our MockupDB tests. Handle that case for db level
        # aggregate too by falling back to the <db>.$cmd.aggregate namespace.
        fallback_ns = self._cursor_namespace
        namespace = cursor.get("ns", fallback_ns)
        # Split on the first dot only: everything after it is the collection
        # name, which may itself contain dots.
        _db_name, coll_name = namespace.split(".", 1)
        return self._database[coll_name]