diff --git a/stubs/grpcio/@tests/test_cases/check_server_interceptor.py b/stubs/grpcio/@tests/test_cases/check_server_interceptor.py index d5ef592a61ad..07d4e89e9d1c 100644 --- a/stubs/grpcio/@tests/test_cases/check_server_interceptor.py +++ b/stubs/grpcio/@tests/test_cases/check_server_interceptor.py @@ -1,8 +1,8 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from concurrent.futures.thread import ThreadPoolExecutor -from typing import Awaitable, TypeVar +from typing import Awaitable, TypeVar, cast import grpc import grpc.aio @@ -26,10 +26,18 @@ def intercept_service( class NoopAioInterceptor(grpc.aio.ServerInterceptor): async def intercept_service( self, - continuation: Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler[RequestT, ResponseT]]], + continuation: Callable[[grpc.HandlerCallDetails], Awaitable[grpc.aio.RpcMethodHandler[RequestT, ResponseT]]], handler_call_details: grpc.HandlerCallDetails, - ) -> grpc.RpcMethodHandler[RequestT, ResponseT]: + ) -> grpc.aio.RpcMethodHandler[RequestT, ResponseT]: return await continuation(handler_call_details) grpc.aio.server(interceptors=[NoopAioInterceptor()]) + +aio_handler = cast(grpc.aio.RpcMethodHandler[bytes, bytes], object()) +aio_handler.unary_unary = cast(Callable[[bytes, grpc.aio.ServicerContext[bytes, bytes]], Awaitable[bytes]], None) +aio_handler.unary_stream = cast(Callable[[bytes, grpc.aio.ServicerContext[bytes, bytes]], AsyncIterator[bytes]], None) +aio_handler.stream_unary = cast(Callable[[AsyncIterator[bytes], grpc.aio.ServicerContext[bytes, bytes]], Awaitable[bytes]], None) +aio_handler.stream_stream = cast( + Callable[[AsyncIterator[bytes], grpc.aio.ServicerContext[bytes, bytes]], AsyncIterator[bytes]], None +) diff --git a/stubs/grpcio/grpc/aio/__init__.pyi b/stubs/grpcio/grpc/aio/__init__.pyi index 4c8ef012fb29..a2f42a9e631b 100644 --- a/stubs/grpcio/grpc/aio/__init__.pyi +++ b/stubs/grpcio/grpc/aio/__init__.pyi @@ -15,7 +15,6 @@ from grpc import ( GenericRpcHandler, HandlerCallDetails, RpcError, - RpcMethodHandler, ServerCredentials, StatusCode, _Options, @@ -369,6 +368,21 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta): request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest], ) -> AsyncIterator[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ... +# Service-Side Handler: + +@type_check_only +class RpcMethodHandler(Generic[_TRequest, _TResponse]): + request_streaming: bool + response_streaming: bool + + request_deserializer: _Deserializer[_TRequest] | None + response_serializer: _Serializer[_TResponse] | None + + unary_unary: Callable[[_TRequest, ServicerContext[_TRequest, _TResponse]], Awaitable[_TResponse]] | None + unary_stream: Callable[[_TRequest, ServicerContext[_TRequest, _TResponse]], AsyncIterator[_TResponse]] | None + stream_unary: Callable[[AsyncIterator[_TRequest], ServicerContext[_TRequest, _TResponse]], Awaitable[_TResponse]] | None + stream_stream: Callable[[AsyncIterator[_TRequest], ServicerContext[_TRequest, _TResponse]], AsyncIterator[_TResponse]] | None + # Server-Side Interceptor: class ServerInterceptor(metaclass=abc.ABCMeta):