From 21ee98bfc92682d45c0271ce99923b4be46e79d8 Mon Sep 17 00:00:00 2001 From: Semen Syrovatskiy Date: Thu, 15 Feb 2024 11:35:09 +0300 Subject: [PATCH] feat(client): improve error message for http timeouts --- src/prisma/_async_http.py | 16 +++++++++ src/prisma/_constants.py | 11 ++++++- src/prisma/_sync_http.py | 16 +++++++++ src/prisma/_types.py | 2 ++ src/prisma/errors.py | 9 +++++ src/prisma/http_abstract.py | 10 ++---- src/prisma/utils.py | 59 +++++++++++++++++++++++++++++++-- tests/test_client.py | 4 +-- tests/test_http.py | 26 ++++++++++++++- tests/test_utils.py | 66 +++++++++++++++++++++++++++++++++++++ 10 files changed, 206 insertions(+), 13 deletions(-) create mode 100644 tests/test_utils.py diff --git a/src/prisma/_async_http.py b/src/prisma/_async_http.py index 615b01bf8..812d29b95 100644 --- a/src/prisma/_async_http.py +++ b/src/prisma/_async_http.py @@ -4,15 +4,28 @@ import httpx +from .utils import ExcConverter from ._types import Method +from .errors import HTTPClientTimeoutError from .http_abstract import AbstractHTTP, AbstractResponse __all__ = ('HTTP', 'AsyncHTTP', 'Response', 'client') +convert_exc = ExcConverter( + { + httpx.ConnectTimeout: HTTPClientTimeoutError, + httpx.ReadTimeout: HTTPClientTimeoutError, + httpx.WriteTimeout: HTTPClientTimeoutError, + httpx.PoolTimeout: HTTPClientTimeoutError, + } +) + + class AsyncHTTP(AbstractHTTP[httpx.AsyncClient, httpx.Response]): session: httpx.AsyncClient + @convert_exc @override async def download(self, url: str, dest: str) -> None: async with self.session.stream('GET', url, timeout=None) as resp: @@ -21,14 +34,17 @@ async def download(self, url: str, dest: str) -> None: async for chunk in resp.aiter_bytes(): fd.write(chunk) + @convert_exc @override async def request(self, method: Method, url: str, **kwargs: Any) -> 'Response': return Response(await self.session.request(method, url, **kwargs)) + @convert_exc @override def open(self) -> None: self.session = httpx.AsyncClient(**self.session_kwargs) + @convert_exc @override async def close(self) -> None: if self.should_close(): diff --git a/src/prisma/_constants.py b/src/prisma/_constants.py index d13005c33..fa218ffcf 100644 --- a/src/prisma/_constants.py +++ b/src/prisma/_constants.py @@ -1,10 +1,19 @@ -from typing import Dict +from typing import Any, Dict from datetime import timedelta +import httpx + DEFAULT_CONNECT_TIMEOUT: timedelta = timedelta(seconds=10) DEFAULT_TX_MAX_WAIT: timedelta = timedelta(milliseconds=2000) DEFAULT_TX_TIMEOUT: timedelta = timedelta(milliseconds=5000) +DEFAULT_HTTP_LIMITS: httpx.Limits = httpx.Limits(max_connections=1000) +DEFAULT_HTTP_TIMEOUT: httpx.Timeout = httpx.Timeout(30) +DEFAULT_HTTP_CONFIG: Dict[str, Any] = { + 'limits': DEFAULT_HTTP_LIMITS, + 'timeout': DEFAULT_HTTP_TIMEOUT, +} + # key aliases to transform query arguments to make them more pythonic QUERY_BUILDER_ALIASES: Dict[str, str] = { 'startswith': 'startsWith', diff --git a/src/prisma/_sync_http.py b/src/prisma/_sync_http.py index 9228fb2b0..35563ce6c 100644 --- a/src/prisma/_sync_http.py +++ b/src/prisma/_sync_http.py @@ -3,15 +3,28 @@ import httpx +from .utils import ExcConverter from ._types import Method +from .errors import HTTPClientTimeoutError from .http_abstract import AbstractHTTP, AbstractResponse __all__ = ('HTTP', 'SyncHTTP', 'Response', 'client') +convert_exc = ExcConverter( + { + httpx.ConnectTimeout: HTTPClientTimeoutError, + httpx.ReadTimeout: HTTPClientTimeoutError, + httpx.WriteTimeout: HTTPClientTimeoutError, + httpx.PoolTimeout: HTTPClientTimeoutError, + } +) + + class SyncHTTP(AbstractHTTP[httpx.Client, httpx.Response]): session: httpx.Client + @convert_exc @override def download(self, url: str, dest: str) -> None: with self.session.stream('GET', url, timeout=None) as resp: @@ -20,14 +33,17 @@ def download(self, url: str, dest: str) -> None: for chunk in resp.iter_bytes(): fd.write(chunk) + @convert_exc @override def request(self, method: Method, url: str, **kwargs: Any) -> 'Response': return Response(self.session.request(method, url, **kwargs)) + @convert_exc @override def open(self) -> None: self.session = httpx.Client(**self.session_kwargs) + @convert_exc @override def close(self) -> None: if self.should_close(): diff --git a/src/prisma/_types.py b/src/prisma/_types.py index b5fbb7a88..4c43e2d25 100644 --- a/src/prisma/_types.py +++ b/src/prisma/_types.py @@ -23,6 +23,8 @@ FuncType = Callable[..., object] CoroType = Callable[..., Coroutine[Any, Any, object]] +ExcMapping = Mapping[Type[BaseException], Type[BaseException]] + @runtime_checkable class InheritsGeneric(Protocol): diff --git a/src/prisma/errors.py b/src/prisma/errors.py index e2aca4af2..a9d82e112 100644 --- a/src/prisma/errors.py +++ b/src/prisma/errors.py @@ -12,6 +12,7 @@ 'TableNotFoundError', 'RecordNotFoundError', 'HTTPClientClosedError', + 'HTTPClientTimeoutError', 'ClientNotConnectedError', 'PrismaWarning', 'UnsupportedSubclassWarning', @@ -44,6 +45,14 @@ def __init__(self) -> None: super().__init__('Cannot make a request from a closed client.') +class HTTPClientTimeoutError(PrismaError): + def __init__(self) -> None: + super().__init__( + 'HTTP operation has timed out.\n' + 'The default timeout is 30 seconds. Maybe you should increase it: prisma.Prisma(http_config={"timeout": httpx.Timeout(30)})' + ) + + class UnsupportedDatabaseError(PrismaError): context: str database: str diff --git a/src/prisma/http_abstract.py b/src/prisma/http_abstract.py index e66ed5713..7ad6584e3 100644 --- a/src/prisma/http_abstract.py +++ b/src/prisma/http_abstract.py @@ -12,22 +12,18 @@ ) from typing_extensions import override -from httpx import Limits, Headers, Timeout +from httpx import Headers from .utils import _NoneType from ._types import Method from .errors import HTTPClientClosedError +from ._constants import DEFAULT_HTTP_CONFIG Session = TypeVar('Session') Response = TypeVar('Response') ReturnType = TypeVar('ReturnType') MaybeCoroutine = Union[Coroutine[Any, Any, ReturnType], ReturnType] -DEFAULT_CONFIG: Dict[str, Any] = { - 'limits': Limits(max_connections=1000), - 'timeout': Timeout(30), -} - class AbstractHTTP(ABC, Generic[Session, Response]): session_kwargs: Dict[str, Any] @@ -45,7 +41,7 @@ def __init__(self, **kwargs: Any) -> None: # Session = open self._session: Optional[Union[Session, Type[_NoneType]]] = _NoneType self.session_kwargs = { - **DEFAULT_CONFIG, + **DEFAULT_HTTP_CONFIG, **kwargs, } diff --git a/src/prisma/utils.py b/src/prisma/utils.py index 69265291f..e853470ca 100644 --- a/src/prisma/utils.py +++ b/src/prisma/utils.py @@ -6,11 +6,13 @@ import inspect import logging import warnings +import functools import contextlib -from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine +from types import TracebackType +from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine from importlib.util import find_spec -from ._types import CoroType, FuncType, TypeGuard +from ._types import CoroType, FuncType, TypeGuard, ExcMapping if TYPE_CHECKING: from typing_extensions import TypeGuard @@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None: def is_dict(obj: object) -> TypeGuard[dict[object, object]]: return isinstance(obj, dict) + + +# TODO: improve typing +class MaybeAsyncContextDecorator(contextlib.ContextDecorator): + """`ContextDecorator` compatible with sync/async functions.""" + + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type: ignore[override] + @functools.wraps(func) + async def async_inner(*args: Any, **kwargs: Any) -> object: + async with self._recreate_cm(): # type: ignore[attr-defined] + return await func(*args, **kwargs) + + @functools.wraps(func) + def sync_inner(*args: Any, **kwargs: Any) -> object: + with self._recreate_cm(): # type: ignore[attr-defined] + return func(*args, **kwargs) + + if is_coroutine(func): + return async_inner + else: + return sync_inner + + +class ExcConverter(MaybeAsyncContextDecorator): + """`MaybeAsyncContextDecorator` to convert exceptions.""" + + def __init__(self, exc_mapping: ExcMapping) -> None: + self._exc_mapping = exc_mapping + + def __enter__(self) -> 'ExcConverter': + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is not None and exc_type is not None: + target_exc_type = self._exc_mapping.get(exc_type) + if target_exc_type is not None: + raise target_exc_type() from exc + + async def __aenter__(self) -> 'ExcConverter': + return self.__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.__exit__(exc_type, exc, exc_tb) diff --git a/tests/test_client.py b/tests/test_client.py index 5f4406ae9..b920c272e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,10 +10,10 @@ from prisma import ENGINE_TYPE, SCHEMA_PATH, Prisma, errors, get_client from prisma.types import HttpConfig from prisma.testing import reset_client +from prisma._constants import DEFAULT_HTTP_CONFIG from prisma.cli.prisma import run from prisma.engine.http import HTTPEngine from prisma.engine.errors import AlreadyConnectedError -from prisma.http_abstract import DEFAULT_CONFIG from .utils import Testdir, patch_method @@ -140,7 +140,7 @@ async def _test(config: HttpConfig) -> None: captured = getter() assert captured is not None - assert captured == ((), {**DEFAULT_CONFIG, **config}) + assert captured == ((), {**DEFAULT_HTTP_CONFIG, **config}) await _test({'timeout': 1}) await _test({'timeout': httpx.Timeout(5, connect=10, read=30)}) diff --git a/tests/test_http.py b/tests/test_http.py index 3fc6ff58e..284927c7e 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -2,11 +2,12 @@ import httpx import pytest +from pytest_mock import MockerFixture from prisma.http import HTTP from prisma.utils import _NoneType from prisma._types import Literal -from prisma.errors import HTTPClientClosedError +from prisma.errors import HTTPClientClosedError, HTTPClientTimeoutError from .utils import patch_method @@ -81,3 +82,26 @@ async def test_httpx_default_config(monkeypatch: 'MonkeyPatch') -> None: 'timeout': httpx.Timeout(30), }, ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'httpx_error', + [ + httpx.ConnectTimeout(''), + httpx.ReadTimeout(''), + httpx.WriteTimeout(''), + httpx.PoolTimeout(''), + ], +) +async def test_http_timeout_error(httpx_error: BaseException, mocker: MockerFixture) -> None: + """Ensure that `httpx.TimeoutException` is converted to `prisma.errors.HTTPClientTimeoutError`.""" + mocker.patch('httpx.AsyncClient.request', side_effect=httpx_error) + + http = HTTP() + http.open() + + with pytest.raises(HTTPClientTimeoutError) as exc_info: + await http.request('GET', '/') + + assert exc_info.value.__cause__ == httpx_error diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..777852d0f --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,66 @@ +import asyncio +from typing import Type, NoReturn + +import pytest + +from prisma.utils import ExcConverter + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('convert_exc', 'raised_exc_type', 'expected_exc_type', 'should_be_converted'), + [ + pytest.param(ExcConverter({ValueError: ImportError}), ValueError, ImportError, True, id='should convert'), + pytest.param( + ExcConverter({ValueError: ImportError}), RuntimeError, RuntimeError, False, id='should not convert' + ), + ], +) +async def test_exc_converter( + convert_exc: ExcConverter, + raised_exc_type: Type[BaseException], + expected_exc_type: Type[BaseException], + should_be_converted: bool, +) -> None: + """Ensure that `prisma.utils.ExcConverter` works as expected.""" + + # Test sync context manager + with pytest.raises(expected_exc_type) as exc_info_1: + with convert_exc: + raise raised_exc_type() + + # Test async context manager + with pytest.raises(expected_exc_type) as exc_info_2: + async with convert_exc: + await asyncio.sleep(0.1) + raise raised_exc_type() + + # Test sync decorator + with pytest.raises(expected_exc_type) as exc_info_3: + + @convert_exc + def help_func() -> NoReturn: + raise raised_exc_type() + + help_func() + + # Test async decorator + with pytest.raises(expected_exc_type) as exc_info_4: + + @convert_exc + async def help_func() -> NoReturn: + await asyncio.sleep(0.1) + raise raised_exc_type() + + await help_func() + + # Test exception cause + if should_be_converted: + assert all( + ( + type(exc_info_1.value.__cause__) is raised_exc_type, + type(exc_info_2.value.__cause__) is raised_exc_type, + type(exc_info_3.value.__cause__) is raised_exc_type, + type(exc_info_4.value.__cause__) is raised_exc_type, + ) + )