328 lines
10 KiB
Python
328 lines
10 KiB
Python
# Copyright 2009-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import traceback
|
|
from collections import abc
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Container,
|
|
Iterable,
|
|
Mapping,
|
|
NoReturn,
|
|
Optional,
|
|
Sequence,
|
|
Union,
|
|
)
|
|
|
|
from pymongo import ASCENDING
|
|
from pymongo.errors import (
|
|
CursorNotFound,
|
|
DuplicateKeyError,
|
|
ExecutionTimeout,
|
|
NotPrimaryError,
|
|
OperationFailure,
|
|
WriteConcernError,
|
|
WriteError,
|
|
WTimeoutError,
|
|
_wtimeout_error,
|
|
)
|
|
from pymongo.hello import HelloCompat
|
|
|
|
if TYPE_CHECKING:
|
|
from pymongo.cursor_shared import _Hint
|
|
from pymongo.operations import _IndexList
|
|
from pymongo.typings import _DocumentOut
|
|
|
|
|
|
# "Node is shutting down" error codes, as listed by the SDAM spec.
_SHUTDOWN_CODES: frozenset = frozenset(
    (
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    )
)

# The SDAM spec's "not primary" codes combined with its "node is recovering"
# codes. The "node is shutting down" codes are a subset of the latter, so
# this set is built on top of _SHUTDOWN_CODES.
_NOT_PRIMARY_CODES: frozenset = _SHUTDOWN_CODES.union(
    (
        10058,  # LegacyNotPrimary <=3.2 "not primary" error code
        10107,  # NotWritablePrimary
        13435,  # NotPrimaryNoSecondaryOk
        11602,  # InterruptedDueToReplStateChange
        13436,  # NotPrimaryOrSecondary
        189,  # PrimarySteppedDown
    )
)

# Retryable error codes from the retryable writes spec: every "not primary"
# code plus the network/timeout style codes below.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES.union(
    (
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
        134,  # ReadConcernMajorityNotAvailableYet
    )
)
|
|
|
|
# Error code the server raises when re-authentication is required.
_REAUTHENTICATION_REQUIRED_CODE: int = 391

# Error code the server raises when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
|
|
|
|
# Names of the commands treated as sensitive.
#
# Note - every entry is deliberately all lowercase. That avoids bugs from
# forgetting which of these are all lowercase and which are camelCase, and
# avoids having to add a test for every command: test membership against
# command_name.lower().
_SENSITIVE_COMMANDS: set = {
    # Authentication commands.
    "authenticate",
    "saslstart",
    "saslcontinue",
    "getnonce",
    # User management commands.
    "createuser",
    "updateuser",
    # copydb and its authentication variants.
    "copydbgetnonce",
    "copydbsaslstart",
    "copydb",
}
|
|
|
|
|
|
def _gen_index_name(keys: _IndexList) -> str:
    """Build an index name from the entries of ``keys``.

    Each entry contributes ``"<first>_<second>"`` (its first two items), and
    the per-entry pieces are joined with ``"_"``. For example
    ``[("a", 1), ("b", -1)]`` becomes ``"a_1_b_-1"``.
    """
    render = "{}_{}".format
    # Star-unpack each entry so that only its first two items are used.
    return "_".join(render(*entry) for entry in keys)
|
|
|
|
|
|
def _index_list(
    key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
    """Normalize an index specification into a list of (key, direction) pairs.

    Accepted inputs:

    - a single key together with ``direction``;
    - a single key alone, which is paired with ``ASCENDING``;
    - an items view or a mapping, whose items become the pairs;
    - a list or tuple whose entries are pairs or bare key strings (bare
      strings are paired with ``ASCENDING``).

    Raises ``TypeError`` when ``direction`` is given with a non-string key, or
    when no direction is given and ``key_or_list`` is none of the above.
    """
    if direction is not None:
        # An explicit direction only makes sense for exactly one key.
        if isinstance(key_or_list, str):
            return [(key_or_list, direction)]
        raise TypeError("Expected a string and a direction")

    if isinstance(key_or_list, str):
        return [(key_or_list, ASCENDING)]
    if isinstance(key_or_list, abc.ItemsView):
        return list(key_or_list)  # type: ignore[arg-type]
    if isinstance(key_or_list, abc.Mapping):
        return list(key_or_list.items())
    if not isinstance(key_or_list, (list, tuple)):
        raise TypeError("if no direction is specified, key_or_list must be an instance of list")

    # Bare key names default to ascending; any other entry is passed through.
    pairs: list[tuple[str, int]] = [
        (entry, ASCENDING) if isinstance(entry, str) else entry for entry in key_or_list
    ]
    return pairs
|
|
|
|
|
|
def _index_document(index_list: _IndexList) -> dict[str, Any]:
    """Build an index-specifying document from ``index_list``.

    ``index_list`` is either a mapping of key to direction, or a list/tuple
    whose entries are (key, direction) pairs or bare key strings (a bare
    string is paired with ``ASCENDING``). Every pair is checked with
    ``_validate_index_key_pair`` before it is stored.

    Raises ``TypeError`` if ``index_list`` is not a list, tuple or mapping,
    and ``ValueError`` if it is empty.
    """
    if not isinstance(index_list, (list, tuple, abc.Mapping)):
        raise TypeError(
            f"must use a dictionary or a list of (key, direction) pairs, not: {index_list!r}"
        )
    if len(index_list) == 0:
        raise ValueError("key_or_list must not be empty")

    # Both input shapes are turned into one lazy stream of pairs so that the
    # validate-and-store loop below is shared.
    if isinstance(index_list, abc.Mapping):
        pairs = ((name, index_list[name]) for name in index_list)
    else:
        pairs = (
            (entry, ASCENDING) if isinstance(entry, str) else entry for entry in index_list
        )

    document: dict[str, Any] = {}
    for key, value in pairs:
        _validate_index_key_pair(key, value)
        document[key] = value
    return document
|
|
|
|
|
|
def _validate_index_key_pair(key: Any, value: Any) -> None:
    """Check the types of one (key, direction) index pair.

    ``key`` must be a ``str``; ``value`` must be a ``str``, an ``int`` or a
    mapping. Raises ``TypeError`` otherwise and returns ``None`` on success.
    """
    if not isinstance(key, str):
        raise TypeError("first item in each key pair must be an instance of str")
    if isinstance(value, (str, int, abc.Mapping)):
        return
    raise TypeError(
        "second item in each key pair must be 1, -1, "
        "'2d', or another valid MongoDB index specifier."
    )
|
|
|
|
|
|
def _check_command_response(
    response: _DocumentOut,
    max_wire_version: Optional[int],
    allowable_errors: Optional[Container[Union[int, str]]] = None,
    parse_write_concern_error: bool = False,
) -> None:
    """Inspect a command response and raise the matching error on failure.

    Returns ``None`` when ``response["ok"]`` is truthy, or when the failure's
    code (or, if it has no code, its message) is in ``allowable_errors``.

    Checks run in this order:

    1. No ``"ok"`` field: ``OperationFailure`` built from ``"$err"``/``"code"``.
    2. ``parse_write_concern_error`` is set and the response has a
       ``"writeConcernError"``: any top-level ``"errorLabels"`` are merged
       into that document **in place**, and it is raised through
       ``_raise_write_concern_error``. This happens before ``"ok"`` is
       looked at.
    3. Truthy ``"ok"``: return.
    4. Otherwise the error details are read (from the first failed shard under
       ``"raw"`` if there is one, else from the response itself) and mapped to
       ``NotPrimaryError``, ``DuplicateKeyError``, ``ExecutionTimeout``,
       ``CursorNotFound`` or plain ``OperationFailure``.

    ``max_wire_version`` is only passed through to the raised exceptions.
    """
    if "ok" not in response:
        # Server didn't recognize our message as a command.
        raise OperationFailure(
            response.get("$err"),  # type: ignore[arg-type]
            response.get("code"),
            response,
            max_wire_version,
        )

    if parse_write_concern_error and "writeConcernError" in response:
        wc_error = response["writeConcernError"]
        labels = response.get("errorLabels")
        if labels:
            # Labels live at the top level of the response; attach them to
            # the write concern error document itself (mutates the response).
            wc_error.update({"errorLabels": labels})
        _raise_write_concern_error(wc_error)

    if response["ok"]:
        return

    details = response
    if "raw" in response:
        # Mongos returns the error details in a 'raw' object for some errors.
        # Use the first shard entry that has a message and is not ok; fall
        # back to the top-level response when none qualifies.
        failed_shards = (
            shard
            for shard in response["raw"].values()
            if shard.get("errmsg") and not shard.get("ok")
        )
        details = next(failed_shards, response)

    errmsg = details["errmsg"]
    code = details.get("code")

    # For allowable errors, the message is only consulted when there is no
    # code at all.
    if allowable_errors:
        allowed_key = errmsg if code is None else code
        if allowed_key in allowable_errors:
            return

    # Server is "not primary" or "recovering": decided by code when there is
    # one, otherwise by matching the message text.
    if code is None:
        not_primary = HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg
    else:
        not_primary = code in _NOT_PRIMARY_CODES
    if not_primary:
        raise NotPrimaryError(errmsg, response)

    # Other errors.
    # findAndModify with upsert can raise duplicate key error.
    if code in (11000, 11001, 12582):
        raise DuplicateKeyError(errmsg, code, response, max_wire_version)
    if code == 50:
        raise ExecutionTimeout(errmsg, code, response, max_wire_version)
    if code == 43:
        raise CursorNotFound(errmsg, code, response, max_wire_version)

    raise OperationFailure(errmsg, code, response, max_wire_version)
|
|
|
|
|
|
def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
    """Raise an exception for the final entry of ``write_errors``.

    If the last batch had multiple errors, only the last one is reported, to
    emulate continue_on_error. Code 11000 is raised as ``DuplicateKeyError``;
    any other code is raised as ``WriteError``.
    """
    last = write_errors[-1]
    message = last.get("errmsg")
    code = last.get("code")
    if code == 11000:
        raise DuplicateKeyError(message, 11000, last)
    raise WriteError(message, code, last)
|
|
|
|
|
|
def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise the exception matching a write concern error document.

    ``WTimeoutError`` is raised when ``_wtimeout_error(error)`` is true and
    ``WriteConcernError`` otherwise. Both are built from the document's
    ``"errmsg"`` and ``"code"`` plus the document itself.
    """
    # Make sure timeouts surface as WTimeoutError specifically.
    error_cls = WTimeoutError if _wtimeout_error(error) else WriteConcernError
    raise error_cls(error.get("errmsg"), error.get("code"), error)
|
|
|
|
|
|
def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
    """Return the ``writeConcernError`` document of ``result``, if any.

    When ``result`` also has truthy top-level ``errorLabels``, a copy of the
    write concern error with those labels attached is returned, so the
    original document is left unchanged. Otherwise the value stored under
    ``"writeConcernError"`` is returned as is (``None`` when the key is
    absent).
    """
    wce = result.get("writeConcernError")
    if not wce:
        return wce

    # The server reports errorLabels at the top level, but it is more
    # convenient to carry them on the writeConcernError document itself.
    labels = result.get("errorLabels")
    if not labels:
        return wce

    labelled = wce.copy()
    labelled["errorLabels"] = labels
    return labelled
|
|
|
|
|
|
def _check_write_command_response(result: Mapping[str, Any]) -> None:
    """Backward compatibility helper for write command error handling.

    Raises for ``result["writeErrors"]`` when that is non-empty; otherwise
    raises for the write concern error document, if there is one. Returns
    ``None`` when neither is present.
    """
    # Write errors take precedence over write concern errors.
    errors = result.get("writeErrors")
    if errors:
        _raise_last_write_error(errors)

    concern_error = _get_wce_doc(result)
    if concern_error:
        _raise_write_concern_error(concern_error)
|
|
|
|
|
|
def _fields_list_to_dict(
    fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
    """Turn a collection of field names into a ``{name: 1}`` dictionary.

    ``["a", "b"]`` becomes ``{"a": 1, "b": 1}`` and
    ``["a.b.c", "d", "a.c"]`` becomes ``{"a.b.c": 1, "d": 1, "a.c": 1}``.

    A mapping is returned unchanged. Any other sequence or set must contain
    only ``str`` items. Note that a bare ``str`` is itself a sequence of
    one-character strings and is handled as such.

    Raises ``TypeError`` (naming ``option_name``) when ``fields`` is neither a
    mapping, a sequence nor a set, or when one of its items is not a ``str``.
    """
    if isinstance(fields, abc.Mapping):
        return fields

    if not isinstance(fields, (abc.Sequence, abc.Set)):
        raise TypeError(f"{option_name} must be a mapping or list of key names")

    for name in fields:
        if not isinstance(name, str):
            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
    return dict.fromkeys(fields, 1)
|
|
|
|
|
|
def _handle_exception() -> None:
    """Print exceptions raised by subscribers to stderr.

    Writes the traceback of the exception currently being handled to
    ``sys.stderr``. Does nothing when ``sys.stderr`` is falsy, and ignores an
    ``OSError`` raised while printing.
    """
    # Heavily influenced by logging.Handler.handleError.

    # sys.stderr can be None; see the note here:
    # https://docs.python.org/3/library/sys.html#sys.__stderr__
    if not sys.stderr:
        return

    exc_type, exc_value, exc_tb = sys.exc_info()
    try:
        traceback.print_exception(exc_type, exc_value, exc_tb, None, sys.stderr)
    except OSError:
        pass
    finally:
        # Drop the local references to the exception and its traceback.
        del exc_type, exc_value, exc_tb
|