474 lines
19 KiB
Python
474 lines
19 KiB
Python
# Copyright 2019-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
|
# may not use this file except in compliance with the License. You
|
|
# may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
|
|
context.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import socket as _socket
|
|
import ssl as _stdlibssl
|
|
import sys as _sys
|
|
import time as _time
|
|
from errno import EINTR as _EINTR
|
|
from ipaddress import ip_address as _ip_address
|
|
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
|
|
|
import cryptography.x509 as x509
|
|
import service_identity
|
|
from OpenSSL import SSL as _SSL
|
|
from OpenSSL import crypto as _crypto
|
|
|
|
from pymongo.errors import ConfigurationError as _ConfigurationError
|
|
from pymongo.errors import _CertificateError # type:ignore[attr-defined]
|
|
from pymongo.ocsp_cache import _OCSPCache
|
|
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
|
|
from pymongo.socket_checker import SocketChecker as _SocketChecker
|
|
from pymongo.socket_checker import _errno_from_exception
|
|
from pymongo.write_concern import validate_boolean
|
|
|
|
if TYPE_CHECKING:
    # Needed for annotations only.
    from ssl import VerifyMode


# Generic return type of the callables wrapped by _sslConn._call.
_T = TypeVar("_T")

# certifi is an optional dependency; remember whether it could be imported.
try:
    import certifi
except ImportError:
    _HAVE_CERTIFI = False
else:
    _HAVE_CERTIFI = True
|
|
|
|
PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD

# Option flags that are always available.
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# This isn't currently documented for PyOpenSSL, so default to 0 when the
# attribute is missing.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Feature flags; always available.
HAS_SNI = True
IS_PYOPENSSL = True

# Base Exception class
SSLError = _SSL.Error

# Translation from the stdlib ssl.CERT_* modes to PyOpenSSL verify flags.
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse lookup: PyOpenSSL verify flags back to the stdlib ssl.CERT_* mode.
_REVERSE_VERIFY_MAP = {
    openssl_flags: stdlib_mode for stdlib_mode, openssl_flags in _VERIFY_MAP.items()
}
|
|
|
|
|
|
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address: Any) -> bool:
    """Return True if ``address`` is accepted by ``ipaddress.ip_address``.

    A ValueError or UnicodeError raised while parsing means "not an IP
    address" and yields False.
    """
    try:
        _ip_address(address)
    except (ValueError, UnicodeError):
        return False
    return True
|
|
|
|
|
|
# Aliases for the individual "operation would block" errors.
BLOCKING_IO_READ_ERROR = _SSL.WantReadError
BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError
BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError

# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
BLOCKING_IO_ERRORS = (
    BLOCKING_IO_READ_ERROR,
    BLOCKING_IO_WRITE_ERROR,
    BLOCKING_IO_LOOKUP_ERROR,
)
|
|
|
|
|
|
def _ragged_eof(exc: BaseException) -> bool:
    """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.

    The decision is made purely on ``exc.args``, which must equal
    ``(-1, "Unexpected EOF")``.
    """
    ragged_eof_args = (-1, "Unexpected EOF")
    return exc.args == ragged_eof_args
|
|
|
|
|
|
# Background reading:
# https://github.com/pyca/pyopenssl/issues/168
# https://github.com/pyca/pyopenssl/issues/176
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
    """PyOpenSSL connection whose I/O calls are retried on blocking errors.

    The handshake, receive and send operations are routed through `_call`,
    which waits on the socket and retries whenever one of
    BLOCKING_IO_ERRORS is raised.
    """

    def __init__(
        self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
    ):
        # Both attributes are assigned before the base class is initialised.
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket_checker = _SocketChecker()
        super().__init__(ctx, sock)

    def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Invoke ``call(*args, **kwargs)``, retrying on blocking I/O errors.

        Raises socket.timeout once the connection's timeout (if any) has
        elapsed, and SSLError if the underlying socket has been closed.
        """
        timeout = self.gettimeout()
        # The clock is only consulted when a truthy timeout is configured.
        started = _time.monotonic() if timeout else None

        def out_of_time() -> bool:
            return bool(timeout) and _time.monotonic() - started > timeout

        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                # Check for closed socket.
                if self.fileno() == -1:
                    if out_of_time():
                        raise _socket.timeout("timed out") from None
                    raise SSLError("Underlying socket has been closed") from None
                # (want_read, want_write) for the pending operation; any
                # other blocking error waits for both directions.
                if isinstance(exc, _SSL.WantReadError):
                    wanted = (True, False)
                elif isinstance(exc, _SSL.WantWriteError):
                    wanted = (False, True)
                else:
                    wanted = (True, True)
                self.socket_checker.select(self, wanted[0], wanted[1], timeout)
                if out_of_time():
                    raise _socket.timeout("timed out") from None

    def do_handshake(self, *args: Any, **kwargs: Any) -> None:
        """Perform the TLS handshake through the retrying `_call` wrapper."""
        handshake = super().do_handshake
        return self._call(handshake, *args, **kwargs)

    def recv(self, *args: Any, **kwargs: Any) -> bytes:
        """Receive data, returning ``b""`` on a suppressed ragged EOF."""
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            return b""

    def recv_into(self, *args: Any, **kwargs: Any) -> int:
        """Receive into a buffer, returning ``0`` on a suppressed ragged EOF."""
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            return 0

    def sendall(self, buf: bytes, flags: int = 0) -> None:  # type: ignore[override]
        """Send all of ``buf``, looping until every byte has been written.

        Raises OSError("connection closed") if a send reports that no data
        was written.
        """
        data = memoryview(buf)
        size = len(buf)
        offset = 0
        while offset < size:
            try:
                written = self._call(super().send, data[offset:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:
                if _errno_from_exception(exc) != _EINTR:
                    raise
                continue
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if written <= 0:
                raise OSError("connection closed")
            offset += written
|
|
|
|
|
|
class _CallbackData:
    """Data class which is passed to the OCSP callback.

    Holds the trusted CA certificates, the OCSP-endpoint-check switch (both
    initially None) and a fresh OCSP response cache.
    """

    def __init__(self) -> None:
        self.ocsp_response_cache = _OCSPCache()
        self.check_ocsp_endpoint: Optional[bool] = None
        self.trusted_ca_certs: Optional[list[x509.Certificate]] = None
|
|
|
|
|
|
class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.

    The wrapped ``OpenSSL.SSL.Context`` is kept in ``_ctx``; the OCSP
    callback data object is kept in ``_callback_data``.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol: int):
        """Create the underlying PyOpenSSL context for ``protocol``.

        Hostname checking and OCSP endpoint checking both start enabled, and
        the OCSP client callback is registered on the new context.
        """
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self) -> int:
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self) -> VerifyMode:
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value: VerifyMode) -> None:
        """Setter for verify_mode.

        ``value`` is translated through _VERIFY_MAP, so a value that is not a
        key of that mapping raises KeyError.
        """

        def _cb(
            _connobj: _SSL.Connection,
            _x509obj: _crypto.X509,
            _errnum: int,
            _errdepth: int,
            retcode: int,
        ) -> bool:
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            # This is optional in pyopenssl >= 20 and can be removed once minimum
            # supported version is bumped
            # See: pyopenssl.org/en/latest/changelog.html#id47
            return bool(retcode)

        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self) -> bool:
        """Whether wrap_socket verifies the peer's hostname after the handshake."""
        return self._check_hostname

    def __set_check_hostname(self, value: Any) -> None:
        """Setter for check_hostname; ``value`` is checked by validate_boolean."""
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self) -> Optional[bool]:
        """The check_ocsp_endpoint flag stored on the OCSP callback data."""
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value: bool) -> None:
        """Setter for check_ocsp_endpoint; ``value`` is checked by validate_boolean."""
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self) -> int:
        """Return the context's current option bitmask."""
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value: int) -> None:
        """Add ``value`` to the context's option bitmask."""
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    options = property(__get_options, __set_options)

    def load_cert_chain(
        self,
        certfile: Union[str, bytes],
        keyfile: Union[str, bytes, None] = None,
        password: Optional[str] = None,
    ) -> None:
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.

        A truthy ``password`` is handed to OpenSSL, UTF-8 encoded, through a
        password callback.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:

            def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: Optional[bytes]) -> bytes:
                # XXX:We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                assert password is not None
                return password.encode("utf-8")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(
        self, cafile: Optional[str] = None, capath: Optional[str] = None
    ) -> None:
        """Load a set of "certification authority"(CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        if not hasattr(_SSL.Connection, "get_verified_chain"):
            assert cafile is not None
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self) -> None:
        """Attempt to load CA certs from certifi.

        Raises ConfigurationError when certifi could not be imported.
        """
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store: str) -> None:
        """Attempt to load CA certs from Windows trust store."""
        cert_store = self._ctx.get_cert_store()
        assert cert_store is not None
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid

        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            # Only DER ("x509_asn") certificates that are trusted for every
            # purpose (trust is True) or list the server-auth OID are added.
            if encoding == "x509_asn":
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(x509.load_der_x509_certificate(cert))
                    )

    def load_default_certs(self) -> None:
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self) -> None:
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def _verify_peer_hostname(self, ssl_conn: _sslConn, server_hostname: str) -> None:
        """Verify that the peer of ``ssl_conn`` is valid for ``server_hostname``.

        IP address literals are checked with verify_ip_address, everything
        else with verify_hostname. Raises _CertificateError when
        service_identity reports a certificate or verification error.
        """
        # Imported lazily, only when a hostname check is actually performed.
        from service_identity import pyopenssl

        try:
            if _is_ip_address(server_hostname):
                pyopenssl.verify_ip_address(ssl_conn, server_hostname)
            else:
                pyopenssl.verify_hostname(ssl_conn, server_hostname)
        except (  # type:ignore[misc]
            service_identity.SICertificateError,
            service_identity.SIVerificationError,
        ) as exc:
            raise _CertificateError(str(exc)) from None

    async def a_wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        Coroutine counterpart of wrap_socket: the OCSP request and the
        handshake are run in the running loop's default executor.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        loop = asyncio.get_running_loop()
        if session:
            ssl_conn.set_session(session)
        if server_side is True:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                await loop.run_in_executor(None, ssl_conn.request_ocsp)
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            await loop.run_in_executor(None, ssl_conn.do_handshake)
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                self._verify_peer_hostname(ssl_conn, server_hostname)
        return ssl_conn

    def wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        When ``do_handshake_on_connect`` is true the handshake is performed
        here and, if check_hostname is enabled and ``server_hostname`` is
        given, the peer certificate is verified against it.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        if server_side is True:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                self._verify_peer_hostname(ssl_conn, server_hostname)
        return ssl_conn
|