# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Internal class to monitor a topology of one or more servers."""
from __future__ import annotations
import logging
import os
import queue
import random
import sys
import time
import warnings
import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared
from pymongo.errors import (
ConnectionFailure,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
ServerSelectionTimeoutError,
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
_debug_log,
_SDAMStatusMessage,
_ServerSelectionStatusMessage,
)
from pymongo.pool_options import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import (
Selection,
any_server_selector,
arbiter_server_selector,
secondary_server_selector,
writable_server_selector,
)
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.pool import Pool
from pymongo.synchronous.server import Server
from pymongo.topology_description import (
SRV_POLLING_TOPOLOGIES,
TOPOLOGY_TYPE,
TopologyDescription,
_updated_topology_description_srv_polling,
updated_topology_description,
)
if TYPE_CHECKING:
from bson import ObjectId
from pymongo.synchronous.settings import TopologySettings
from pymongo.typings import ClusterTime, _Address
_IS_SYNC = True
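# Directory containing this module. On Python 3.12+ it is passed to
# warnings.warn(skip_file_prefixes=...) in Topology.open() so the fork-safety
# warning is attributed to the caller's code rather than to driver internals.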
_pymongo_dir = str(Path(__file__).parent)
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
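    """Drain pending listener events from the weakly-referenced queue.

    Returns False to stop the PeriodicExecutor once the Topology (and its
    events queue) has been garbage collected, and True to keep polling.
    """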
q = queue_ref()
if not q:
return False # Cancel PeriodicExecutor.
while True:
try:
event = q.get_nowait()
except queue.Empty:
break
else:
fn, args = event
fn(*args)
return True # Continue PeriodicExecutor.
class Topology:
"""Monitor a topology of one or more servers."""
def __init__(self, topology_settings: TopologySettings):
self._topology_id = topology_settings._topology_id
self._listeners = topology_settings._pool_options._event_listeners
self._publish_server = self._listeners is not None and self._listeners.enabled_for_server
self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology
# Create events queue if there are publishers.
self._events = None
self.__events_executor: Any = None
if self._publish_server or self._publish_tp:
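            # Bounded queue of (listener_callback, args) pairs, drained on a
            # background thread by process_events_queue().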
self._events = queue.Queue(maxsize=100)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
message=_SDAMStatusMessage.START_TOPOLOGY,
)
if self._publish_tp:
assert self._events is not None
self._events.put((self._listeners.publish_topology_opened, (self._topology_id,)))
self._settings = topology_settings
topology_description = TopologyDescription(
topology_settings.get_topology_type(),
topology_settings.get_server_descriptions(),
topology_settings.replica_set_name,
None,
None,
topology_settings,
)
self._description = topology_description
initial_td = TopologyDescription(
TOPOLOGY_TYPE.Unknown, {}, None, None, None, self._settings
)
if self._publish_tp:
assert self._events is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
(initial_td, self._description, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
previousDescription=initial_td,
newDescription=self._description,
message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
)
for seed in topology_settings.seeds:
if self._publish_server:
assert self._events is not None
self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id)))
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
serverHost=seed[0],
serverPort=seed[1],
message=_SDAMStatusMessage.START_SERVER,
)
# Store the seed list to help diagnose errors in _error_message().
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
self._session_pool = _ServerSessionPool()
if self._publish_server or self._publish_tp:
assert self._events is not None
weak: weakref.ReferenceType[queue.Queue]
def target() -> bool:
return process_events_queue(weak)
executor = periodic_executor.PeriodicExecutor(
interval=common.EVENTS_QUEUE_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
name="pymongo_events_thread",
)
# We strongly reference the executor and it weakly references
# the queue via this closure. When the topology is freed, stop
# the executor soon.
weak = weakref.ref(self._events, executor.close)
self.__events_executor = executor
executor.open()
self._srv_monitor = None
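        # A SrvMonitor is created only for SRV-seeded clients (fqdn is set)
        # and never in load balanced mode; it is opened later in _ensure_opened().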
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
def open(self) -> None:
"""Start monitoring, or restart after a fork.
No effect if called multiple times.
.. warning:: Topology is shared among multiple threads and is protected
by mutual exclusion. Using Topology from a process other than the one
that initialized it will emit a warning and may result in deadlock. To
prevent this from happening, MongoClient must be created after any
forking.
"""
pid = os.getpid()
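        # A changed PID means this client was created before a fork; the child
        # process inherited monitors, pools, and sockets that cannot be reused.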
if self._pid is None:
self._pid = pid
elif pid != self._pid:
self._pid = pid
if sys.version_info[:2] >= (3, 12):
kwargs = {"skip_file_prefixes": (_pymongo_dir,)}
else:
kwargs = {"stacklevel": 6}
# Ignore B028 warning for missing stacklevel.
warnings.warn( # type: ignore[call-overload] # noqa: B028
"MongoClient opened before fork. May not be entirely fork-safe, "
"proceed with caution. See PyMongo's documentation for details: "
"https://www.mongodb.com/docs/languages/"
"python/pymongo-driver/current/faq/"
"#is-pymongo-fork-safe-",
**kwargs,
)
with self._lock:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
with self._lock:
self._ensure_opened()
def get_server_selection_timeout(self) -> float:
# CSOT: use remaining timeout when set.
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
def select_servers(
self,
selector: Callable[[Selection], Selection],
operation: str,
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.
:param selector: function that takes a list of Servers and returns
a subset of them.
:param operation: The name of the operation that the server is being selected for.
:param server_selection_timeout: maximum seconds to wait.
If not provided, the default value common.SERVER_SELECTION_TIMEOUT
is used.
:param address: optional server address to select.
Calls self.open() if needed.
        Raises :exc:`ServerSelectionTimeoutError` after
`server_selection_timeout` if no matching servers are found.
"""
if server_selection_timeout is None:
server_timeout = self.get_server_selection_timeout()
else:
server_timeout = server_selection_timeout
with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
)
return [
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
]
def _select_servers_loop(
self,
selector: Callable[[Selection], Selection],
timeout: float,
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
end_time = now + timeout
logged_waiting = False
if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SERVER_SELECTION_LOGGER,
message=_ServerSelectionStatusMessage.STARTED,
selector=selector,
operation=operation,
operationId=operation_id,
topologyDescription=self.description,
clientId=self.description._topology_settings._topology_id,
)
server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
)
while not server_descriptions:
# No suitable servers.
if timeout == 0 or now > end_time:
if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SERVER_SELECTION_LOGGER,
message=_ServerSelectionStatusMessage.FAILED,
selector=selector,
operation=operation,
operationId=operation_id,
topologyDescription=self.description,
clientId=self.description._topology_settings._topology_id,
failure=self._error_message(selector),
)
raise ServerSelectionTimeoutError(
f"{self._error_message(selector)}, Timeout: {timeout}s, Topology Description: {self.description!r}"
)
if not logged_waiting:
_debug_log(
_SERVER_SELECTION_LOGGER,
message=_ServerSelectionStatusMessage.WAITING,
selector=selector,
operation=operation,
operationId=operation_id,
topologyDescription=self.description,
clientId=self.description._topology_settings._topology_id,
remainingTimeMS=int(end_time - time.monotonic()),
)
logged_waiting = True
self._ensure_opened()
self._request_check_all()
# Release the lock and wait for the topology description to
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
)
self._description.check_compatible()
return server_descriptions
def _select_server(
self,
selector: Callable[[Selection], Selection],
operation: str,
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
operation_id: Optional[int] = None,
) -> Server:
servers = self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
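        # Of two servers chosen at random, prefer the one with fewer
        # in-progress operations ("power of two choices" load balancing).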
server1, server2 = random.sample(servers, 2)
if server1.pool.operation_count <= server2.pool.operation_count:
return server1
else:
return server2
def select_server(
self,
selector: Callable[[Selection], Selection],
operation: str,
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
operation_id: Optional[int] = None,
) -> Server:
"""Like select_servers, but choose a random server if several match."""
server = self._select_server(
selector,
operation,
server_selection_timeout,
address,
deprioritized_servers,
operation_id=operation_id,
)
if _csot.get_timeout():
_csot.set_rtt(server.description.min_round_trip_time)
if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SERVER_SELECTION_LOGGER,
message=_ServerSelectionStatusMessage.SUCCEEDED,
selector=selector,
operation=operation,
operationId=operation_id,
topologyDescription=self.description,
clientId=self.description._topology_settings._topology_id,
serverHost=server.description.address[0],
serverPort=server.description.address[1],
)
return server
def select_server_by_address(
self,
address: _Address,
operation: str,
server_selection_timeout: Optional[int] = None,
operation_id: Optional[int] = None,
) -> Server:
"""Return a Server for "address", reconnecting if necessary.
If the server's type is not known, request an immediate check of all
servers. Time out after "server_selection_timeout" if the server
cannot be reached.
:param address: A (host, port) pair.
:param operation: The name of the operation that the server is being selected for.
:param server_selection_timeout: maximum seconds to wait.
If not provided, the default value
common.SERVER_SELECTION_TIMEOUT is used.
:param operation_id: The unique id of the current operation being performed. Defaults to None if not provided.
Calls self.open() if needed.
        Raises :exc:`ServerSelectionTimeoutError` after
`server_selection_timeout` if no matching servers are found.
"""
return self.select_server(
any_server_selector,
operation,
server_selection_timeout,
address,
operation_id=operation_id,
)
def _process_change(
self,
server_description: ServerDescription,
reset_pool: bool = False,
interrupt_connections: bool = False,
) -> None:
"""Process a new ServerDescription on an opened topology.
Hold the lock when calling this.
"""
td_old = self._description
sd_old = td_old._server_descriptions[server_description.address]
if _is_stale_server_description(sd_old, server_description):
# This is a stale hello response. Ignore it.
return
new_td = updated_topology_description(self._description, server_description)
# CMAP: Ensure the pool is "ready" when the server is selectable.
if server_description.is_readable or (
server_description.is_server_type_known and new_td.topology_type == TOPOLOGY_TYPE.Single
):
server = self._servers.get(server_description.address)
if server:
server.pool.ready()
suppress_event = sd_old == server_description
if self._publish_server and not suppress_event:
assert self._events is not None
self._events.put(
(
self._listeners.publish_server_description_changed,
(sd_old, server_description, server_description.address, self._topology_id),
)
)
self._description = new_td
self._update_servers()
self._receive_cluster_time_no_lock(server_description.cluster_time)
if self._publish_tp and not suppress_event:
assert self._events is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event:
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
previousDescription=td_old,
newDescription=self._description,
message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
)
# Shutdown SRV polling for unsupported cluster types.
# This is only applicable if the old topology was Unknown, and the
# new one is something other than Unknown or Sharded.
if self._srv_monitor and (
td_old.topology_type == TOPOLOGY_TYPE.Unknown
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
self._srv_monitor.close()
# Clear the pool from a failed heartbeat.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
# Wake anything waiting in select_servers().
self._condition.notify_all()
def on_change(
self,
server_description: ServerDescription,
reset_pool: bool = False,
interrupt_connections: bool = False,
) -> None:
"""Process a new ServerDescription after an hello call completes."""
# We do no I/O holding the lock.
with self._lock:
# Monitors may continue working on hello calls for some time
# after a call to Topology.close, so this method may be called at
# any time. Ensure the topology is open before processing the
# change.
# Any monitored server was definitely in the topology description
# once. Check if it's still in the description or if some state-
# change removed it. E.g., we got a host list from the primary
# that didn't include this server.
if self._opened and self._description.has_server(server_description.address):
self._process_change(server_description, reset_pool, interrupt_connections)
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.
Hold the lock when calling this.
"""
td_old = self._description
if td_old.topology_type not in SRV_POLLING_TOPOLOGIES:
return
self._description = _updated_topology_description_srv_polling(self._description, seedlist)
self._update_servers()
if self._publish_tp:
assert self._events is not None
self._events.put(
(
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
previousDescription=td_old,
newDescription=self._description,
message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
)
def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new list of nodes obtained from scanning SRV records."""
# We do no I/O holding the lock.
with self._lock:
if self._opened:
self._process_srv_update(seedlist)
def get_server_by_address(self, address: _Address) -> Optional[Server]:
"""Get a Server or None.
Returns the current version of the server immediately, even if it's
Unknown or absent from the topology. Only use this in unittests.
In driver code, use select_server_by_address, since then you're
assured a recent view of the server's type and wire protocol version.
"""
return self._servers.get(address)
def has_server(self, address: _Address) -> bool:
return address in self._servers
def get_primary(self) -> Optional[_Address]:
"""Return primary's address or None."""
# Implemented here in Topology instead of MongoClient, so it can lock.
with self._lock:
topology_type = self._description.topology_type
if topology_type != TOPOLOGY_TYPE.ReplicaSetWithPrimary:
return None
return writable_server_selector(self._new_selection())[0].address
def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> set[_Address]:
"""Return set of replica set member addresses."""
# Implemented here in Topology instead of MongoClient, so it can lock.
with self._lock:
topology_type = self._description.topology_type
if topology_type not in (
TOPOLOGY_TYPE.ReplicaSetWithPrimary,
TOPOLOGY_TYPE.ReplicaSetNoPrimary,
):
return set()
return {sd.address for sd in iter(selector(self._new_selection()))}
def get_secondaries(self) -> set[_Address]:
"""Return set of secondary addresses."""
return self._get_replica_set_members(secondary_server_selector)
def get_arbiters(self) -> set[_Address]:
"""Return set of arbiter addresses."""
return self._get_replica_set_members(arbiter_server_selector)
def max_cluster_time(self) -> Optional[ClusterTime]:
"""Return a document, the highest seen $clusterTime."""
return self._max_cluster_time
def _receive_cluster_time_no_lock(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
# Driver Sessions Spec: "Whenever a driver receives a cluster time from
# a server it MUST compare it to the current highest seen cluster time
# for the deployment. If the new cluster time is higher than the
# highest seen cluster time it MUST become the new highest seen cluster
# time. Two cluster times are compared using only the BsonTimestamp
# value of the clusterTime embedded field."
if cluster_time:
# ">" uses bson.timestamp.Timestamp's comparison operator.
if (
not self._max_cluster_time
or cluster_time["clusterTime"] > self._max_cluster_time["clusterTime"]
):
self._max_cluster_time = cluster_time
def receive_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
with self._lock:
self._receive_cluster_time_no_lock(cluster_time)
def request_check_all(self, wait_time: int = 5) -> None:
"""Wake all monitors, wait for at least one to check its server."""
with self._lock:
self._request_check_all()
self._condition.wait(wait_time)
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.
This includes any server that might be selected for an operation.
"""
if self._description.topology_type == TOPOLOGY_TYPE.Single:
return self._description.known_servers
return self._description.readable_servers
def update_pool(self) -> None:
# Remove any stale sockets and add new sockets if pool is too small.
servers = []
with self._lock:
# Only update pools for data-bearing servers.
for sd in self.data_bearing_servers():
server = self._servers[sd.address]
servers.append((server, server.pool.gen.get_overall()))
for server, generation in servers:
try:
server.pool.remove_stale_sockets(generation)
except PyMongoError as exc:
ctx = _ErrorContext(exc, 0, generation, False, None)
self.handle_error(server.description.address, ctx)
raise
def close(self) -> None:
"""Clear pools and terminate monitors. Topology does not reopen on
demand. Any further operations will raise
:exc:`~.errors.InvalidOperation`.
"""
with self._lock:
old_td = self._description
for server in self._servers.values():
server.close()
# Mark all servers Unknown.
self._description = self._description.reset()
for address, sd in self._description.server_descriptions().items():
if address in self._servers:
self._servers[address].description = sd
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
self._opened = False
self._closed = True
# Publish only after releasing the lock.
if self._publish_tp:
assert self._events is not None
self._description = TopologyDescription(
TOPOLOGY_TYPE.Unknown,
{},
self._description.replica_set_name,
self._description.max_set_version,
self._description.max_election_id,
self._description._topology_settings,
)
self._events.put(
(
self._listeners.publish_topology_description_changed,
(
old_td,
self._description,
self._topology_id,
),
)
)
self._events.put((self._listeners.publish_topology_closed, (self._topology_id,)))
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
previousDescription=old_td,
newDescription=self._description,
message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
)
_debug_log(
_SDAM_LOGGER, topologyId=self._topology_id, message=_SDAMStatusMessage.STOP_TOPOLOGY
)
if self._publish_server or self._publish_tp:
# Make sure the events executor thread is fully closed before publishing the remaining events
self.__events_executor.close()
self.__events_executor.join(1)
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
@property
def description(self) -> TopologyDescription:
return self._description
def pop_all_sessions(self) -> list[_ServerSession]:
"""Pop all session ids from the pool."""
return self._session_pool.pop_all()
def get_server_session(self, session_timeout_minutes: Optional[int]) -> _ServerSession:
"""Start or resume a server session, or raise ConfigurationError."""
return self._session_pool.get_server_session(session_timeout_minutes)
def return_server_session(self, server_session: _ServerSession) -> None:
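        """Return a server session to the pool so it can be reused."""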
self._session_pool.return_server_session(server_session)
def _new_selection(self) -> Selection:
"""A Selection object, initially including all known servers.
Hold the lock when calling this.
"""
return Selection.from_topology_description(self._description)
def _ensure_opened(self) -> None:
"""Start monitors, or restart after a fork.
Hold the lock when calling this.
"""
if self._closed:
raise InvalidOperation("Cannot use MongoClient after close")
if not self._opened:
self._opened = True
self._update_servers()
# Start or restart the events publishing thread.
if self._publish_tp or self._publish_server:
self.__events_executor.open()
# Start the SRV polling thread.
if self._srv_monitor and (self.description.topology_type in SRV_POLLING_TOPOLOGIES):
self._srv_monitor.open()
if self._settings.load_balanced:
# Emit initial SDAM events for load balancer mode.
self._process_change(
ServerDescription(
self._seed_addresses[0],
Hello({"ok": 1, "serviceId": self._topology_id, "maxWireVersion": 13}),
)
)
# Ensure that the monitors are open.
for server in self._servers.values():
server.open()
def _is_stale_error(self, address: _Address, err_ctx: _ErrorContext) -> bool:
server = self._servers.get(address)
if server is None:
# Another thread removed this server from the topology.
return True
if server._pool.stale_generation(err_ctx.sock_generation, err_ctx.service_id):
# This is an outdated error from a previous pool version.
return True
# topologyVersion check, ignore error when cur_tv >= error_tv:
cur_tv = server.description.topology_version
error = err_ctx.error
error_tv = None
if error and hasattr(error, "details"):
if isinstance(error.details, dict):
error_tv = error.details.get("topologyVersion")
return _is_stale_error_topology_version(cur_tv, error_tv)
def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
if self._is_stale_error(address, err_ctx):
return
server = self._servers[address]
error = err_ctx.error
service_id = err_ctx.service_id
# Ignore a handshake error if the server is behind a load balancer but
# the service ID is unknown. This indicates that the error happened
# when dialing the connection or during the MongoDB handshake, so we
# don't know the service ID to use for clearing the pool.
if self._settings.load_balanced and not service_id and not err_ctx.completed_handshake:
return
if isinstance(error, NetworkTimeout) and err_ctx.completed_handshake:
# The socket has been closed. Don't reset the server.
# Server Discovery And Monitoring Spec: "When an application
# operation fails because of any network error besides a socket
# timeout...."
return
elif isinstance(error, WriteError):
# Ignore writeErrors.
return
elif isinstance(error, (NotPrimaryError, OperationFailure)):
# As per the SDAM spec if:
# - the server sees a "not primary" error, and
# - the server is not shutting down, and
# - the server version is >= 4.2, then
# we keep the existing connection pool, but mark the server type
# as Unknown and request an immediate check of the server.
# Otherwise, we clear the connection pool, mark the server as
# Unknown and request an immediate check of the server.
if hasattr(error, "code"):
err_code = error.code
else:
# Default error code if one does not exist.
default = 10107 if isinstance(error, NotPrimaryError) else None
err_code = error.details.get("code", default) # type: ignore[union-attr]
if err_code in helpers_shared._NOT_PRIMARY_CODES:
is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES
# Mark server Unknown, clear the pool, and request check.
if not self._settings.load_balanced:
self._process_change(ServerDescription(address, error=error))
if is_shutting_down or (err_ctx.max_wire_version <= 7):
# Clear the pool.
server.reset(service_id)
server.request_check()
elif not err_ctx.completed_handshake:
# Unknown command error during the connection handshake.
if not self._settings.load_balanced:
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset(service_id)
elif isinstance(error, ConnectionFailure):
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."
if not self._settings.load_balanced:
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset(service_id)
# "When a client marks a server Unknown from `Network error when
# reading or writing`_, clients MUST cancel the hello check on
# that server and close the current monitoring connection."
server._monitor.cancel_check()
def handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
"""Handle an application error.
May reset the server to Unknown, clear the pool, and request an
immediate check depending on the error and the context.
"""
with self._lock:
self._handle_error(address, err_ctx)
def _request_check_all(self) -> None:
"""Wake all monitors. Hold the lock when calling this."""
for server in self._servers.values():
server.request_check()
def _update_servers(self) -> None:
"""Sync our Servers from TopologyDescription.server_descriptions.
Hold the lock while calling this.
"""
for address, sd in self._description.server_descriptions().items():
if address not in self._servers:
monitor = self._settings.monitor_class(
server_description=sd,
topology=self,
pool=self._create_pool_for_monitor(address),
topology_settings=self._settings,
)
weak = None
if self._publish_server and self._events is not None:
weak = weakref.ref(self._events)
server = Server(
server_description=sd,
pool=self._create_pool_for_server(address),
monitor=monitor,
topology_id=self._topology_id,
listeners=self._listeners,
events=weak,
)
self._servers[address] = server
server.open()
else:
# Cache old is_writable value.
was_writable = self._servers[address].description.is_writable
# Update server description.
self._servers[address].description = sd
# Update is_writable value of the pool, if it changed.
if was_writable != sd.is_writable:
self._servers[address].pool.update_is_writable(sd.is_writable)
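        # Close and discard Server instances for addresses that are no longer
        # part of the topology description.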
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
server.close()
self._servers.pop(address)
def _create_pool_for_server(self, address: _Address) -> Pool:
return self._settings.pool_class(
address, self._settings.pool_options, client_id=self._topology_id
)
def _create_pool_for_monitor(self, address: _Address) -> Pool:
options = self._settings.pool_options
# According to the Server Discovery And Monitoring Spec, monitors use
# connect_timeout for both connect_timeout and socket_timeout. The
# pool only has one socket so maxPoolSize and so on aren't needed.
monitor_pool_options = PoolOptions(
connect_timeout=options.connect_timeout,
socket_timeout=options.connect_timeout,
ssl_context=options._ssl_context,
tls_allow_invalid_hostnames=options.tls_allow_invalid_hostnames,
event_listeners=options._event_listeners,
appname=options.appname,
driver=options.driver,
pause_enabled=False,
server_api=options.server_api,
)
return self._settings.pool_class(
address, monitor_pool_options, handshake=False, client_id=self._topology_id
)
def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
"""Format an error message if server selection fails.
Hold the lock when calling this.
"""
is_replica_set = self._description.topology_type in (
TOPOLOGY_TYPE.ReplicaSetWithPrimary,
TOPOLOGY_TYPE.ReplicaSetNoPrimary,
)
if is_replica_set:
server_plural = "replica set members"
elif self._description.topology_type == TOPOLOGY_TYPE.Sharded:
server_plural = "mongoses"
else:
server_plural = "servers"
if self._description.known_servers:
# We've connected, but no servers match the selector.
if selector is writable_server_selector:
if is_replica_set:
return "No primary available for writes"
else:
return "No %s available for writes" % server_plural
else:
return f'No {server_plural} match selector "{selector}"'
else:
addresses = list(self._description.server_descriptions())
servers = list(self._description.server_descriptions().values())
if not servers:
if is_replica_set:
# We removed all servers because of the wrong setName?
return 'No {} available for replica set name "{}"'.format(
server_plural,
self._settings.replica_set_name,
)
else:
return "No %s available" % server_plural
# 1 or more servers, all Unknown. Are they unknown for one reason?
error = servers[0].error
same = all(server.error == error for server in servers[1:])
if same:
if error is None:
# We're still discovering.
return "No %s found yet" % server_plural
if is_replica_set and not set(addresses).intersection(self._seed_addresses):
# We replaced our seeds with new hosts but can't reach any.
return (
"Could not reach any servers in %s. Replica set is"
" configured with internal hostnames or IPs?" % addresses
)
return str(error)
else:
return ",".join(str(server.error) for server in servers if server.error)
def __repr__(self) -> str:
msg = ""
if not self._opened:
msg = "CLOSED "
return f"<{self.__class__.__name__} {msg}{self._description!r}>"
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
"""The properties to use for MongoClient/Topology equality checks."""
ts = self._settings
return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.eq_props() == other.eq_props()
return NotImplemented
def __hash__(self) -> int:
return hash(self.eq_props())
class _ErrorContext:
"""An error with context for SDAM error handling."""
def __init__(
self,
error: BaseException,
max_wire_version: int,
sock_generation: int,
completed_handshake: bool,
service_id: Optional[ObjectId],
):
self.error = error
self.max_wire_version = max_wire_version
self.sock_generation = sock_generation
self.completed_handshake = completed_handshake
self.service_id = service_id
def _is_stale_error_topology_version(
current_tv: Optional[Mapping[str, Any]], error_tv: Optional[Mapping[str, Any]]
) -> bool:
"""Return True if the error's topologyVersion is <= current."""
if current_tv is None or error_tv is None:
return False
if current_tv["processId"] != error_tv["processId"]:
return False
return current_tv["counter"] >= error_tv["counter"]
def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
"""Return True if the new topologyVersion is < current."""
current_tv, new_tv = current_sd.topology_version, new_sd.topology_version
if current_tv is None or new_tv is None:
return False
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]
def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates
filtered = [server for server in candidates if server not in deprioritized_servers]
# If not possible to pick a prioritized server, return the original list
return filtered or candidates