queues/venv/lib/python3.11/site-packages/pymongo/helpers_shared.py
Egor Matveev 6c6a549aff
All checks were successful
Deploy Prod / Build (pull_request) Successful in 9s
Deploy Prod / Push (pull_request) Successful in 12s
Deploy Prod / Deploy prod (pull_request) Successful in 10s
fix
2024-12-28 22:48:16 +03:00

328 lines
10 KiB
Python

# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bits and pieces used by the driver that don't really fit elsewhere."""
from __future__ import annotations
import sys
import traceback
from collections import abc
from typing import (
TYPE_CHECKING,
Any,
Container,
Iterable,
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
from pymongo import ASCENDING
from pymongo.errors import (
CursorNotFound,
DuplicateKeyError,
ExecutionTimeout,
NotPrimaryError,
OperationFailure,
WriteConcernError,
WriteError,
WTimeoutError,
_wtimeout_error,
)
from pymongo.hello import HelloCompat
if TYPE_CHECKING:
from pymongo.cursor_shared import _Hint
from pymongo.operations import _IndexList
from pymongo.typings import _DocumentOut
# Error codes the SDAM spec classifies as "node is shutting down".
_SHUTDOWN_CODES: frozenset = frozenset(
    {
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    }
)

# Error codes the SDAM spec classifies as "not primary", merged with its
# "node is recovering" codes. The shutdown codes above are a subset of the
# "recovering" family, so they are folded in here as well.
_NOT_PRIMARY_CODES: frozenset = _SHUTDOWN_CODES.union(
    {
        10058,  # LegacyNotPrimary <=3.2 "not primary" error code
        10107,  # NotWritablePrimary
        13435,  # NotPrimaryNoSecondaryOk
        11602,  # InterruptedDueToReplStateChange
        13436,  # NotPrimaryOrSecondary
        189,  # PrimarySteppedDown
    }
)

# Error codes the retryable writes spec treats as retryable: every
# "not primary"/"recovering" code above plus the additional codes below.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES.union(
    {
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
        134,  # ReadConcernMajorityNotAvailableYet
    }
)

# Server error code meaning the client has to re-authenticate.
_REAUTHENTICATION_REQUIRED_CODE: int = 391
# Server error code meaning authentication failed.
_AUTHENTICATION_FAILURE_CODE: int = 18
# Command names treated as sensitive. Every entry is deliberately lowercase so
# callers can compare against command_name.lower(); this avoids bugs from
# forgetting which of these are all lowercase and which are camelCase, without
# having to add a test for every command. Entries are kept alphabetical.
_SENSITIVE_COMMANDS: set = {
    "authenticate",
    "copydb",
    "copydbgetnonce",
    "copydbsaslstart",
    "createuser",
    "getnonce",
    "saslcontinue",
    "saslstart",
    "updateuser",
}
def _gen_index_name(keys: _IndexList) -> str:
    """Build the default index name for *keys*.

    Each ``(field, direction)`` pair is rendered as ``"<field>_<direction>"``
    and the pieces are joined with underscores, so ``[("a", 1), ("b", -1)]``
    yields ``"a_1_b_-1"``.
    """
    pieces = ("{}_{}".format(*pair) for pair in keys)
    return "_".join(pieces)
def _index_list(
    key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
    """Normalize an index/hint specification into ``(key, direction)`` pairs.

    Accepted forms:

    * a single key string plus an explicit *direction* -> ``[(key, direction)]``;
    * a single key string alone -> ``[(key, ASCENDING)]``;
    * an ``ItemsView`` or a mapping -> its items, in iteration order;
    * a list/tuple whose entries are pairs or bare key strings; each bare
      string is expanded to ``(key, ASCENDING)`` and other entries are kept
      as-is.

    :raises TypeError: if *direction* is given with a non-string key, or if
        no direction is given and *key_or_list* is none of the forms above.
    """
    if direction is not None:
        # An explicit direction is only meaningful for a single key.
        if isinstance(key_or_list, str):
            return [(key_or_list, direction)]
        raise TypeError("Expected a string and a direction")

    if isinstance(key_or_list, str):
        return [(key_or_list, ASCENDING)]
    if isinstance(key_or_list, abc.ItemsView):
        return list(key_or_list)  # type: ignore[arg-type]
    if isinstance(key_or_list, abc.Mapping):
        return list(key_or_list.items())
    if isinstance(key_or_list, (list, tuple)):
        return [
            (entry, ASCENDING) if isinstance(entry, str) else entry
            for entry in key_or_list
        ]
    raise TypeError("if no direction is specified, key_or_list must be an instance of list")
def _index_document(index_list: _IndexList) -> dict[str, Any]:
    """Turn an index specification into a key -> direction document.

    *index_list* may be a mapping of key -> direction, or a list/tuple whose
    entries are ``(key, direction)`` pairs or bare key strings (a bare string
    stands for ``(key, ASCENDING)``). Each pair is checked with
    ``_validate_index_key_pair`` before it is stored; a repeated key
    overwrites the earlier value.

    :raises TypeError: if *index_list* is not a mapping, list, or tuple, or if
        a pair fails validation.
    :raises ValueError: if *index_list* is empty.
    """
    if not isinstance(index_list, (list, tuple, abc.Mapping)):
        raise TypeError(
            "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
        )
    if len(index_list) == 0:
        raise ValueError("key_or_list must not be empty")

    # Lazily produce (key, value) pairs so each one is validated before the
    # next is read, exactly as a plain loop would.
    if isinstance(index_list, abc.Mapping):
        pairs = ((name, index_list[name]) for name in index_list)
    else:
        pairs = (
            (entry, ASCENDING) if isinstance(entry, str) else entry
            for entry in index_list
        )

    document: dict[str, Any] = {}
    for name, spec in pairs:
        _validate_index_key_pair(name, spec)
        document[name] = spec
    return document
def _validate_index_key_pair(key: Any, value: Any) -> None:
    """Check a single ``(key, value)`` entry of an index specification.

    :raises TypeError: if *key* is not a ``str`` (checked first), or if
        *value* is not a ``str``, an ``int``, or a mapping.
    """
    key_is_valid = isinstance(key, str)
    if key_is_valid and isinstance(value, (str, int, abc.Mapping)):
        return
    if not key_is_valid:
        raise TypeError("first item in each key pair must be an instance of str")
    raise TypeError(
        "second item in each key pair must be 1, -1, "
        "'2d', or another valid MongoDB index specifier."
    )
def _check_command_response(
    response: _DocumentOut,
    max_wire_version: Optional[int],
    allowable_errors: Optional[Container[Union[int, str]]] = None,
    parse_write_concern_error: bool = False,
) -> None:
    """Check the response to a command for errors and raise if one is found.

    :param response: the server's reply document.
    :param max_wire_version: passed through to the OperationFailure-family
        exceptions raised here.
    :param allowable_errors: error codes to tolerate instead of raising; when
        the failing reply carries no ``code``, its ``errmsg`` is looked up in
        this container instead.
    :param parse_write_concern_error: when true, a ``writeConcernError`` in
        the reply is raised (via ``_raise_write_concern_error``) before ``ok``
        is inspected.

    :raises OperationFailure: when the reply has no ``ok`` field, or for any
        failure not covered by the more specific cases below.
    :raises NotPrimaryError: when the code is a "not primary"/"recovering"
        code, or (with no code) the message says so.
    :raises DuplicateKeyError: for codes 11000, 11001 and 12582.
    :raises ExecutionTimeout: for code 50.
    :raises CursorNotFound: for code 43.
    """
    if "ok" not in response:
        # Server didn't recognize our message as a command.
        raise OperationFailure(
            response.get("$err"),  # type: ignore[arg-type]
            response.get("code"),
            response,
            max_wire_version,
        )

    if parse_write_concern_error and "writeConcernError" in response:
        wce = response["writeConcernError"]
        labels = response.get("errorLabels")
        if labels:
            # The server reports errorLabels at the top level. Attach them to
            # a copy so the caller's reply document is not modified.
            wce = wce.copy()
            wce["errorLabels"] = labels
        _raise_write_concern_error(wce)

    if response["ok"]:
        return

    details = response
    # Mongos returns the error details in a 'raw' object for some errors:
    # use the first shard reply that failed with a non-empty errmsg.
    if "raw" in response:
        for shard in response["raw"].values():
            if shard.get("errmsg") and not shard.get("ok"):
                details = shard
                break

    # A failed reply normally carries "errmsg"; fall back to "" so a missing
    # or null message still yields a classified driver error, not a KeyError.
    errmsg = details.get("errmsg") or ""
    code = details.get("code")

    # For allowable errors, only check the message when no code is included.
    if allowable_errors:
        if code is not None:
            if code in allowable_errors:
                return
        elif errmsg in allowable_errors:
            return

    # Server is "not primary" or "recovering".
    if code is not None:
        if code in _NOT_PRIMARY_CODES:
            raise NotPrimaryError(errmsg, response)
    elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
        raise NotPrimaryError(errmsg, response)

    # Other errors.
    # findAndModify with upsert can raise a duplicate key error.
    if code in (11000, 11001, 12582):
        raise DuplicateKeyError(errmsg, code, response, max_wire_version)
    if code == 50:
        raise ExecutionTimeout(errmsg, code, response, max_wire_version)
    if code == 43:
        raise CursorNotFound(errmsg, code, response, max_wire_version)
    raise OperationFailure(errmsg, code, response, max_wire_version)
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
    """Raise an exception built from the final entry of *write_errors*.

    When a batch produced several write errors, only the last one is
    reported, emulating continue_on_error. Code 11000 is raised as
    ``DuplicateKeyError``; any other code as ``WriteError``.
    """
    last = write_errors[-1]
    code = last.get("code")
    message = last.get("errmsg")
    if code == 11000:
        raise DuplicateKeyError(message, 11000, last)
    raise WriteError(message, code, last)
def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise the exception matching a writeConcernError document.

    ``WTimeoutError`` is raised when ``_wtimeout_error(error)`` is true,
    otherwise ``WriteConcernError``. Both receive the document's ``errmsg``,
    its ``code``, and the document itself.
    """
    exc_class = WTimeoutError if _wtimeout_error(error) else WriteConcernError
    raise exc_class(error.get("errmsg"), error.get("code"), error)
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
    """Return the reply's writeConcernError, with errorLabels attached.

    Returns ``result.get("writeConcernError")`` unchanged (``None`` when
    absent) unless it is truthy and the reply also has non-empty top-level
    ``errorLabels``. In that case a copy of the writeConcernError with an
    added ``errorLabels`` entry is returned, leaving *result* untouched.
    """
    wce = result.get("writeConcernError")
    # The server reports errorLabels at the top level, but it is handier to
    # have them on the writeConcernError document itself.
    labels = result.get("errorLabels") if wce else None
    if not labels:
        return wce
    merged = wce.copy()
    merged["errorLabels"] = labels
    return merged
def _check_write_command_response(result: Mapping[str, Any]) -> None:
    """Raise for errors reported in a write command reply.

    Backward compatibility helper. Write errors take precedence: if
    ``writeErrors`` is non-empty, the last one is raised. Otherwise, a
    writeConcernError (see ``_get_wce_doc``) is raised if present. Returns
    ``None`` when the reply reports neither.
    """
    write_errors = result.get("writeErrors")
    if write_errors:
        # _raise_last_write_error always raises, so nothing below runs.
        _raise_last_write_error(write_errors)
    else:
        write_concern_error = _get_wce_doc(result)
        if write_concern_error:
            _raise_write_concern_error(write_concern_error)
def _fields_list_to_dict(
    fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
    """Takes a sequence of field names and returns a matching dictionary.

    ["a", "b"] becomes {"a": 1, "b": 1}
    and
    ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}

    A mapping is returned unchanged. A single ``str`` is treated as one field
    name, so "a.b" becomes {"a.b": 1}; an empty string yields {}.

    :param fields: a mapping, a sequence/set of field-name strings, or a
        single field-name string.
    :param option_name: name of the option being converted; used only in
        error messages.
    :raises TypeError: if *fields* is none of the accepted forms, or if a
        sequence/set contains a non-``str`` entry.
    """
    if isinstance(fields, abc.Mapping):
        return fields
    # A str is itself a Sequence of one-character strs. Without this check,
    # "name" would silently become {"n": 1, "a": 1, "m": 1, "e": 1}.
    if isinstance(fields, str):
        return {fields: 1} if fields else {}
    if isinstance(fields, (abc.Sequence, abc.Set)):
        if not all(isinstance(field, str) for field in fields):
            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
        return dict.fromkeys(fields, 1)
    raise TypeError(f"{option_name} must be a mapping or list of key names")
def _handle_exception() -> None:
    """Print the exception currently being handled to ``sys.stderr``.

    Meant for exceptions raised by subscribers. Does nothing when
    ``sys.stderr`` is falsy (e.g. ``None``), and ignores an ``OSError`` from
    the write itself. Modelled on ``logging.Handler.handleError``; see
    https://docs.python.org/3.4/library/sys.html#sys.__stderr__
    """
    if not sys.stderr:
        return
    exc_type, exc_value, exc_tb = sys.exc_info()
    try:
        traceback.print_exception(exc_type, exc_value, exc_tb, None, sys.stderr)
    except OSError:
        # Best effort: a broken stderr must not take down the caller.
        pass
    finally:
        # Drop the traceback references explicitly to avoid keeping frames alive.
        del exc_type, exc_value, exc_tb