From 8c0a0f3eb87da35647169d2649e2ff37ff735512 Mon Sep 17 00:00:00 2001 From: Darsh Patel Date: Thu, 5 Mar 2026 11:26:11 -0800 Subject: [PATCH] Fix mypy arg-type errors in generated discriminated union encoders The codegen for discriminated union TypedDict encoders generates a ternary chain where mypy can't narrow the union type through key-based conditions. Instead of suppressing with `# type: ignore[arg-type]`, use `cast()` to explicitly narrow the type to the correct variant after the discriminator check. This preserves type safety in the generated code. Affects both the single-variant and multi-variant discriminator code paths. --- src/replit_river/codegen/client.py | 8 ++++++-- .../v1/codegen/rpc/generated/test_service/rpc_method.py | 1 + .../generated_special_chars/test_service/rpc_method.py | 1 + .../test_service/anyof_mixed_method.py | 1 + .../test_basic_stream/test_service/emit_error.py | 1 + .../test_basic_stream/test_service/stream_method.py | 1 + .../test_service/pathological_method.py | 1 + .../test_recursive_types/recursiveService/getTree.py | 1 + .../snapshots/test_unknown_enum/enumService/needsEnum.py | 1 + .../test_unknown_enum/enumService/needsEnumObject.py | 9 +++++++-- .../snapshots/test_basic_rpc/test_service/rpc_method.py | 1 + .../test_basic_stream/test_service/stream_method.py | 1 + 12 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index d59936fe..a760acee 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -80,6 +80,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated @@ -301,9 +302,10 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: # "encoder_names" is only a TypedDict thing encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) + _type_name = render_literal_type(type_name) typeddict_encoder.append( f"""\ - {_field_name}(x) # type: ignore[arg-type] + {_field_name}(cast('{_type_name}', x)) """.strip() ) if local_discriminators: @@ -333,8 +335,10 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: # TODO(dstewart): Figure out why uncommenting this breaks # generated code # encoder_names.add(encoder_name) + _type_name = render_literal_type(type_name) typeddict_encoder.append( - f"{render_literal_type(encoder_name)}(x)" + f"{render_literal_type(encoder_name)}" + f"(cast('{_type_name}', x))" ) typeddict_encoder.append( f""" diff --git a/tests/v1/codegen/rpc/generated/test_service/rpc_method.py b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py index d839f4af..96fb2b79 100644 --- a/tests/v1/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/v1/codegen/rpc/generated/test_service/rpc_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/rpc/generated_special_chars/test_service/rpc_method.py b/tests/v1/codegen/rpc/generated_special_chars/test_service/rpc_method.py index 77efaee5..fe1028d6 100644 --- a/tests/v1/codegen/rpc/generated_special_chars/test_service/rpc_method.py +++ b/tests/v1/codegen/rpc/generated_special_chars/test_service/rpc_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py index 4846c795..dce862c0 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py index e7005c29..f68957db 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py +++ b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 1914aefc..fb84de18 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index 2c325e64..7979be13 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_recursive_types/recursiveService/getTree.py b/tests/v1/codegen/snapshot/snapshots/test_recursive_types/recursiveService/getTree.py index b6fd551d..12ed0eb0 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_recursive_types/recursiveService/getTree.py +++ b/tests/v1/codegen/snapshot/snapshots/test_recursive_types/recursiveService/getTree.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 69b976b5..40adc47a 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index e875810a..0f0b0d8b 100644 --- a/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated @@ -71,9 +72,13 @@ def encode_NeedsenumobjectInput( x: "NeedsenumobjectInput", ) -> Any: return ( - encode_NeedsenumobjectInputOneOf_in_first(x) + encode_NeedsenumobjectInputOneOf_in_first( + cast("NeedsenumobjectInputOneOf_in_first", x) + ) if x["kind"] == "in_first" - else encode_NeedsenumobjectInputOneOf_in_second(x) + else encode_NeedsenumobjectInputOneOf_in_second( + cast("NeedsenumobjectInputOneOf_in_second", x) + ) ) diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py index 5fe764eb..139c15d7 100644 --- a/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_rpc/test_service/rpc_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated diff --git a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 9a0fc5d0..3dd84f71 100644 --- a/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/v2/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -7,6 +7,7 @@ Mapping, NotRequired, TypedDict, + cast, ) from typing_extensions import Annotated