Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions stubs/grpcio/@tests/test_cases/check_server_interceptor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Awaitable, TypeVar
from typing import Awaitable, TypeVar, cast

import grpc
import grpc.aio
Expand All @@ -26,10 +26,18 @@ def intercept_service(
class NoopAioInterceptor(grpc.aio.ServerInterceptor):
async def intercept_service(
self,
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler[RequestT, ResponseT]]],
continuation: Callable[[grpc.HandlerCallDetails], Awaitable[grpc.aio.RpcMethodHandler[RequestT, ResponseT]]],
handler_call_details: grpc.HandlerCallDetails,
) -> grpc.RpcMethodHandler[RequestT, ResponseT]:
) -> grpc.aio.RpcMethodHandler[RequestT, ResponseT]:
return await continuation(handler_call_details)


grpc.aio.server(interceptors=[NoopAioInterceptor()])

aio_handler = cast(grpc.aio.RpcMethodHandler[bytes, bytes], object())
aio_handler.unary_unary = cast(Callable[[bytes, grpc.aio.ServicerContext[bytes, bytes]], Awaitable[bytes]], None)
aio_handler.unary_stream = cast(Callable[[bytes, grpc.aio.ServicerContext[bytes, bytes]], AsyncIterator[bytes]], None)
aio_handler.stream_unary = cast(Callable[[AsyncIterator[bytes], grpc.aio.ServicerContext[bytes, bytes]], Awaitable[bytes]], None)
aio_handler.stream_stream = cast(
Callable[[AsyncIterator[bytes], grpc.aio.ServicerContext[bytes, bytes]], AsyncIterator[bytes]], None
)
16 changes: 15 additions & 1 deletion stubs/grpcio/grpc/aio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ from grpc import (
GenericRpcHandler,
HandlerCallDetails,
RpcError,
RpcMethodHandler,
ServerCredentials,
StatusCode,
_Options,
Expand Down Expand Up @@ -369,6 +368,21 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta):
request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest],
) -> AsyncIterator[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ...

# Service-Side Handler:

@type_check_only
class RpcMethodHandler(Generic[_TRequest, _TResponse]):
request_streaming: bool
response_streaming: bool

request_deserializer: _Deserializer[_TRequest] | None
response_serializer: _Serializer[_TResponse] | None

unary_unary: Callable[[_TRequest, ServicerContext[_TRequest, _TResponse]], Awaitable[_TResponse]] | None
unary_stream: Callable[[_TRequest, ServicerContext[_TRequest, _TResponse]], AsyncIterator[_TResponse]] | None
stream_unary: Callable[[AsyncIterator[_TRequest], ServicerContext[_TRequest, _TResponse]], Awaitable[_TResponse]] | None
stream_stream: Callable[[AsyncIterator[_TRequest], ServicerContext[_TRequest, _TResponse]], AsyncIterator[_TResponse]] | None

# Server-Side Interceptor:

class ServerInterceptor(metaclass=abc.ABCMeta):
Expand Down