queues/venv/lib/python3.11/site-packages/pymongo/_csot.py
Egor Matveev 6c6a549aff
All checks were successful
Deploy Prod / Build (pull_request) Successful in 9s
Deploy Prod / Push (pull_request) Successful in 12s
Deploy Prod / Deploy prod (pull_request) Successful in 10s
fix
2024-12-28 22:48:16 +03:00

163 lines
5.0 KiB
Python

# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Internal helpers for CSOT."""
from __future__ import annotations
import functools
import inspect
import time
from collections import deque
from contextlib import AbstractContextManager
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
if TYPE_CHECKING:
from pymongo.write_concern import WriteConcern
# Context-local state used by the CSOT helpers in this module.
#
# DEADLINE holds an absolute point on the time.monotonic() clock; it is
# infinity while no finite deadline has been set.
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))
# RTT holds the value most recently stored with RTT.set(); 0.0 until then.
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
# TIMEOUT holds the active timeout value; None means no timeout is set.
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
def get_timeout() -> Optional[float]:
    """Return the timeout stored in the current context, or None if unset."""
    current = TIMEOUT.get(None)
    return current
def get_rtt() -> float:
    """Return the RTT value stored in the current context."""
    current = RTT.get()
    return current
def get_deadline() -> float:
    """Return the deadline stored in the current context."""
    current = DEADLINE.get()
    return current
def set_rtt(rtt: float) -> None:
    """Store ``rtt`` as the RTT value for the current context."""
    # The Token returned by ContextVar.set() is intentionally discarded:
    # this helper offers no way to restore the previous value.
    _ = RTT.set(rtt)
def remaining() -> Optional[float]:
    """Return the time left until the current deadline.

    Returns None when no timeout is in effect (the stored timeout is None
    or 0). Otherwise returns the stored deadline minus the current
    ``time.monotonic()`` reading, which is negative once the deadline has
    passed.
    """
    if get_timeout():
        return DEADLINE.get() - time.monotonic()
    return None
def clamp_remaining(max_timeout: float) -> float:
    """Return the remaining time, capped at ``max_timeout``.

    When ``remaining()`` reports no active timeout (None), ``max_timeout``
    itself is returned.
    """
    left = remaining()
    return max_timeout if left is None else min(left, max_timeout)
class _TimeoutContext(AbstractContextManager):
    """Internal timeout context manager.

    Use :func:`pymongo.timeout` instead::

        with pymongo.timeout(0.5):
            client.test.test.insert_one({})
    """

    def __init__(self, timeout: Optional[float]):
        # Tokens are only populated by __enter__; None means "not entered".
        self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None
        self._timeout = timeout

    def __enter__(self) -> _TimeoutContext:
        """Install the timeout, the effective deadline and a zeroed RTT."""
        requested = self._timeout
        timeout_token = TIMEOUT.set(requested)
        outer_deadline = DEADLINE.get()
        # A falsy timeout (None or 0) contributes no deadline of its own.
        if requested:
            own_deadline = time.monotonic() + requested
        else:
            own_deadline = float("inf")
        # Never extend a deadline that is already in effect.
        deadline_token = DEADLINE.set(min(outer_deadline, own_deadline))
        rtt_token = RTT.set(0.0)
        self._tokens = (timeout_token, deadline_token, rtt_token)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Restore the context variables to their values before __enter__."""
        tokens = self._tokens
        if not tokens:
            return
        # Reset in the same order the original code used: TIMEOUT, DEADLINE, RTT.
        for var, token in zip((TIMEOUT, DEADLINE, RTT), tokens):
            var.reset(token)
# TypeVar bound to any callable, used to type decorators in this module.
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar(
    "F",
    bound=Callable[..., Any],
)
def apply(func: F) -> F:
    """Decorate a method so the client's timeoutMS applies to the operation.

    Works for both coroutine functions and regular functions. The wrapper
    runs ``func`` inside a ``_TimeoutContext`` built from ``self._timeout``
    only when no timeout is already set in the current context and
    ``self._timeout`` is not None; otherwise it calls ``func`` directly.
    """
    if not inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # A timeout that is already active takes precedence.
            if get_timeout() is not None:
                return func(self, *args, **kwargs)
            client_timeout = self._timeout
            if client_timeout is None:
                return func(self, *args, **kwargs)
            with _TimeoutContext(client_timeout):
                return func(self, *args, **kwargs)

    else:

        @functools.wraps(func)
        async def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # A timeout that is already active takes precedence.
            if get_timeout() is not None:
                return await func(self, *args, **kwargs)
            client_timeout = self._timeout
            if client_timeout is None:
                return await func(self, *args, **kwargs)
            with _TimeoutContext(client_timeout):
                return await func(self, *args, **kwargs)

    return cast(F, csot_wrapper)
def apply_write_concern(
    cmd: MutableMapping[str, Any], write_concern: Optional[WriteConcern]
) -> None:
    """Add ``write_concern`` to ``cmd`` under the ``"writeConcern"`` key.

    Nothing is added when ``write_concern`` is falsy, is the server default,
    or yields an empty document. While a timeout is set in the current
    context, the ``"wtimeout"`` field is dropped from the document first.
    """
    if write_concern and not write_concern.is_server_default:
        document = write_concern.document
        if get_timeout() is not None:
            document.pop("wtimeout", None)
        if document:
            cmd["writeConcern"] = document
_MAX_RTT_SAMPLES: int = 10
_MIN_RTT_SAMPLES: int = 2


class MovingMinimum:
    """Tracks the minimum RTT within a sliding window of recent samples.

    By default the window keeps the last ``_MAX_RTT_SAMPLES`` (10) samples
    and a minimum is reported once ``_MIN_RTT_SAMPLES`` (2) samples are
    available. Both values can be overridden per instance.
    """

    samples: Deque[float]
    # Class-level fallback so instances created without __init__ still work.
    _min_samples: int = _MIN_RTT_SAMPLES

    def __init__(
        self,
        *,
        max_samples: int = _MAX_RTT_SAMPLES,
        min_samples: int = _MIN_RTT_SAMPLES,
    ) -> None:
        """Create an empty window.

        :param max_samples: Number of most recent samples to retain.
        :param min_samples: Number of samples required before ``get()``
            reports a minimum instead of 0.0.
        :raises ValueError: If ``max_samples`` is less than 1, or
            ``min_samples`` is not in ``1..max_samples``.
        """
        if max_samples < 1:
            raise ValueError(f"max_samples must be at least 1, got {max_samples}")
        if not 1 <= min_samples <= max_samples:
            raise ValueError(
                f"min_samples must be between 1 and max_samples, got {min_samples}"
            )
        # deque(maxlen=...) silently evicts the oldest sample when full.
        self.samples = deque(maxlen=max_samples)
        self._min_samples = min_samples

    def add_sample(self, sample: float) -> None:
        """Record a sample.

        :raises ValueError: If ``sample`` is negative.
        """
        if sample < 0:
            raise ValueError(f"duration cannot be negative {sample}")
        self.samples.append(sample)

    def get(self) -> float:
        """Get the min, or 0.0 if there aren't enough samples yet."""
        if len(self.samples) >= self._min_samples:
            return min(self.samples)
        return 0.0

    def reset(self) -> None:
        """Discard all recorded samples."""
        self.samples.clear()