From 11c9eaf7aaa979d3b475d3b104c89884a64a1fe4 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Thu, 11 Jun 2026 10:50:05 -0700 Subject: [PATCH 1/3] fix: EnumTransformer.to_literal accepts a string matching an enum value Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/core/type_engine.py | 29 ++++++++-- tests/flytekit/unit/core/test_type_engine.py | 59 ++++++++++++++++---- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9993c98479..77b27886a3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1081,6 +1081,14 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal( self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType ) -> Literal: + # Accept a raw string that matches one of the enum's values. assert_type already + # allows this (e.g. an enum default supplied as a string), so to_literal must too, + # otherwise such a value passes type-checking but fails serialization. + if isinstance(python_val, str): + enum_type = cast(Type[enum.Enum], python_type) + if python_val not in [item.value for item in enum_type]: + raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}") + return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: @@ -2079,6 +2087,8 @@ async def async_to_literal( is_ambiguous = False res = None res_type = None + # Result of the ``str`` variant, if it matched, used to disambiguate a bare-string value. + str_match: typing.Optional[typing.Tuple[Literal, LiteralType]] = None t = None for i in range(len(get_args(python_type))): try: @@ -2095,6 +2105,8 @@ async def async_to_literal( res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) found_res = True potential_types.append(t) + if t is str: + str_match = (res, res_type) except Exception as e: logger.debug( f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}", @@ -2102,11 +2114,18 @@ async def async_to_literal( continue if is_ambiguous: - raise TypeError( - f"Ambiguous choice of variant for union type.\n" - f"Potential types: {potential_types}\n" - "These types are structurally the same, because it's attributes have the same names and associated types." - ) + # A bare ``str`` is most specifically a ``str``: when the value is exactly a string + # and ``str`` is one of the variants, prefer it over variants that merely accept the + # string (e.g. an enum whose values include it). This narrowly disambiguates the + # ``str`` vs enum-by-string overlap without affecting other structurally-equal variants. + if type(python_val) is str and str_match is not None: + res, res_type = str_match + else: + raise TypeError( + f"Ambiguous choice of variant for union type.\n" + f"Potential types: {potential_types}\n" + "These types are structurally the same, because it's attributes have the same names and associated types." + ) if found_res: return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8945ea46dd..73ce12d97e 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,27 +6,26 @@ import sys import tempfile import typing +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass, field from datetime import timedelta from enum import Enum, auto -from typing import List, Optional, Type, Dict +from typing import Dict, List, Optional, Type import mock import msgpack import pytest import typing_extensions -from concurrent.futures import ThreadPoolExecutor from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import errors_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct -from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from mashumaro.config import BaseConfig from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.mixins.orjson import DataClassORJSONMixin from mashumaro.types import Discriminator -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, get_args from flytekit import dynamic, kwtypes, task, workflow from flytekit.core.annotation import FlyteAnnotation @@ -34,16 +33,16 @@ from flytekit.core.data_persistence import flyte_tmp_dir from flytekit.core.hash import HashMethod from flytekit.core.type_engine import ( - IntTransformer, - FloatTransformer, BoolTransformer, - StrTransformer, DataclassTransformer, DictTransformer, EnumTransformer, + FloatTransformer, + IntTransformer, ListTransformer, LiteralsResolver, SimpleTransformer, + StrTransformer, TypeEngine, TypeTransformer, TypeTransformerFailedError, @@ -68,7 +67,7 @@ LiteralOffloadedMetadata, Primitive, Scalar, - Void, Binary, + Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import TensorboardLogs @@ -79,11 +78,11 @@ from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.iterator.iterator import IteratorTransformer -from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON +from flytekit.types.iterator.json_iterator import JSON, JSONIterator, JSONIteratorTransformer from flytekit.types.pickle import FlytePickle from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.types.schema import FlyteSchema -from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET +from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine T = typing.TypeVar("T") @@ -1532,6 +1531,39 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +@pytest.mark.parametrize("python_val", ["red", "green", "blue"]) +def test_enum_to_literal_accepts_matching_string(python_val): + # A string matching an enum value is accepted by assert_type, so to_literal must + # accept it too (e.g. an enum default supplied as a string). + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(Color) + + lv = TypeEngine.to_literal(ctx, python_val, Color, lt) + assert lv.scalar.primitive.string_value == python_val + + +@pytest.mark.parametrize("python_val", ["purple", "Red", ""]) +def test_enum_to_literal_rejects_non_matching_string(python_val): + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(Color) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(ctx, python_val, Color, lt) + + +def test_enum_string_in_union_prefers_str(): + # A bare string matching an enum value would otherwise match both the ``str`` and enum + # variants of a union. The string is most specifically a ``str``, so the ``str`` variant + # is chosen instead of raising an ambiguity error. + ctx = FlyteContextManager.current_context() + pt = typing.Union[Color, str] + lt = TypeEngine.to_literal_type(pt) + + lv = TypeEngine.to_literal(ctx, "red", pt, lt) + assert lv.scalar.union.value.scalar.primitive.string_value == "red" + assert lv.scalar.union.stored_type.structure.tag == "str" + + def test_multi_inheritance_enum_type(): tfm = TypeEngine.get_transformer(MultiInheritanceColor) assert isinstance(tfm, EnumTransformer) @@ -4112,8 +4144,9 @@ def test_asyncio_wait_empty_kwargs_regression(): "Set of Tasks/Futures is empty" ValueError. """ import asyncio - from flytekit.models import literals as _literal_models + from flytekit.core.type_engine import TypeEngine + from flytekit.models import literals as _literal_models async def simulate_original_bug(): """ @@ -4153,8 +4186,8 @@ def test_error_message_improvements_literal_map_to_kwargs(): Test that error messages in literal_map_to_kwargs use proper repr formatting for better debugging experience. """ - from flytekit.models import literals as _literal_models from flytekit.core.type_engine import TypeTransformerFailedError + from flytekit.models import literals as _literal_models ctx = FlyteContextManager.current_context() @@ -4192,6 +4225,7 @@ def test_error_message_improvements_union_transformer(): for better debugging when conversion fails. """ from typing import Union + from flytekit.models import literals as _literal_models ctx = FlyteContextManager.current_context() @@ -4229,6 +4263,7 @@ def test_debug_logging_union_transformer(caplog): """ import logging from typing import Union + from flytekit.models import literals as _literal_models # Set logging level to capture debug messages From 39c76dc58ee09722b9b55a37083822d1de42ca82 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Tue, 30 Jun 2026 00:07:11 -0700 Subject: [PATCH 2/3] Address review feedback - Simplify EnumTransformer.to_literal: unify the enum/string paths through a single value extraction - UnionTransformer: replace the str-specific disambiguation with a general "exact type wins" tie-break when multiple variants match - Tests: de-parametrize the matching-string case, use docstrings, and update the custom-transformer sanity check for the exact-type tie-break Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/core/type_engine.py | 75 +++++++------------- tests/flytekit/unit/core/test_type_engine.py | 19 +++-- 2 files changed, 35 insertions(+), 59 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 77b27886a3..555d571230 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1079,22 +1079,16 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType(enum_type=_core_types.EnumType(values=values)) def to_literal( - self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + self, ctx: FlyteContext, python_val: Union[enum.Enum, str], python_type: Type[T], expected: LiteralType ) -> Literal: - # Accept a raw string that matches one of the enum's values. assert_type already - # allows this (e.g. an enum default supplied as a string), so to_literal must too, - # otherwise such a value passes type-checking but fails serialization. - if isinstance(python_val, str): - enum_type = cast(Type[enum.Enum], python_type) - if python_val not in [item.value for item in enum_type]: - raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}") - return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) - if type(python_val).__class__ != enum.EnumMeta: - raise TypeTransformerFailedError("Expected an enum") - if type(python_val.value) != str: - raise TypeTransformerFailedError("Only string-valued enums are supported") - - return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore + # Accept a raw string matching one of the enum's values: assert_type already allows it (e.g. an enum + # default supplied as a string), so to_literal must too, else such a value type-checks but fails to serialize. + val = python_val.value if isinstance(python_val, enum.Enum) else python_val + if not isinstance(val, str): + raise TypeTransformerFailedError(f"Expected an enum or matching string, got {type(python_val)}") + if val not in [item.value for item in cast(Type[enum.Enum], python_type)]: + raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}") + return Literal(scalar=Scalar(primitive=Primitive(string_value=val))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: if lv.scalar and lv.scalar.binary: @@ -2082,55 +2076,40 @@ async def async_to_literal( ) -> typing.Union[Literal, asyncio.Future]: python_type = get_underlying_type(python_type) - potential_types = [] - found_res = False - is_ambiguous = False - res = None - res_type = None - # Result of the ``str`` variant, if it matched, used to disambiguate a bare-string value. - str_match: typing.Optional[typing.Tuple[Literal, LiteralType]] = None - t = None - for i in range(len(get_args(python_type))): + # (type, literal, literal_type) for every variant that accepts the value + matches: typing.List[typing.Tuple[type, Literal, LiteralType]] = [] + for i, t in enumerate(get_args(python_type)): try: - t = get_args(python_type)[i] trans: TypeTransformer[T] = TypeEngine.get_transformer(t) if isinstance(trans, AsyncTypeTransformer): - attempt = trans.async_to_literal(ctx, python_val, t, expected.union_type.variants[i]) - res = await attempt + res = await trans.async_to_literal(ctx, python_val, t, expected.union_type.variants[i]) else: res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i]) - if found_res: - logger.debug(f"Current type {get_args(python_type)[i]} old res {res_type}") - is_ambiguous = True res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) - found_res = True - potential_types.append(t) - if t is str: - str_match = (res, res_type) + matches.append((t, res, res_type)) except Exception as e: logger.debug( f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}", ) continue - if is_ambiguous: - # A bare ``str`` is most specifically a ``str``: when the value is exactly a string - # and ``str`` is one of the variants, prefer it over variants that merely accept the - # string (e.g. an enum whose values include it). This narrowly disambiguates the - # ``str`` vs enum-by-string overlap without affecting other structurally-equal variants. - if type(python_val) is str and str_match is not None: - res, res_type = str_match - else: + if not matches: + raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") + + if len(matches) > 1: + # More than one variant matched. Prefer the one whose type is exactly the value's type -- e.g. for + # Union[str, Color] and "red", both str and Color (which accepts a matching string) match, but str wins. + exact = [m for m in matches if type(python_val) is m[0]] + if len(exact) != 1: raise TypeError( - f"Ambiguous choice of variant for union type.\n" - f"Potential types: {potential_types}\n" + "Ambiguous choice of variant for union type.\n" + f"Potential types: {[m[0] for m in matches]}\n" "These types are structurally the same, because it's attributes have the same names and associated types." ) + matches = exact - if found_res: - return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) - - raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") + _, res, res_type = matches[0] + return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 73ce12d97e..0042820f15 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1531,15 +1531,13 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) -@pytest.mark.parametrize("python_val", ["red", "green", "blue"]) -def test_enum_to_literal_accepts_matching_string(python_val): - # A string matching an enum value is accepted by assert_type, so to_literal must - # accept it too (e.g. an enum default supplied as a string). +def test_enum_to_literal_accepts_matching_string(): + """Test that a string matching an enum value is accepted by to_literal.""" ctx = FlyteContextManager.current_context() lt = TypeEngine.to_literal_type(Color) - lv = TypeEngine.to_literal(ctx, python_val, Color, lt) - assert lv.scalar.primitive.string_value == python_val + lv = TypeEngine.to_literal(ctx, "red", Color, lt) + assert lv.scalar.primitive.string_value == "red" @pytest.mark.parametrize("python_val", ["purple", "Red", ""]) @@ -1552,9 +1550,7 @@ def test_enum_to_literal_rejects_non_matching_string(python_val): def test_enum_string_in_union_prefers_str(): - # A bare string matching an enum value would otherwise match both the ``str`` and enum - # variants of a union. The string is most specifically a ``str``, so the ``str`` variant - # is chosen instead of raising an ambiguity error. + """Test that a bare string in Union[Color, str] resolves to str, not the enum, without ambiguity error.""" ctx = FlyteContextManager.current_context() pt = typing.Union[Color, str] lt = TypeEngine.to_literal_type(pt) @@ -2051,8 +2047,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: assert union_type_tags_unique(lt) ctx = FlyteContextManager.current_context() - with pytest.raises(TypeError, match="Ambiguous choice of variant for union type"): - TypeEngine.to_literal(ctx, 3, pt, lt) + # int and UnsignedInt both accept 3, but int is the value's exact type, so it wins without ambiguity. + lv = TypeEngine.to_literal(ctx, 3, pt, lt) + assert lv.scalar.union.stored_type.structure.tag == "int" del TypeEngine._REGISTRY[UnsignedInt] From 033d1cc94ffea737c286ab269fe92e4837126cfd Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Tue, 30 Jun 2026 00:24:26 -0700 Subject: [PATCH 3/3] Narrow union disambiguation to str, preserving the structural-ambiguity guard The general exact-type tie-break removed the guard that makes structurally- identical union variants (e.g. Union[A, B] dataclasses, Union[int, UnsignedInt]) raise on to_literal -- they would then be indistinguishable on read-back. Keep that guard; only a bare string prefers the str variant over an enum-by-value. Signed-off-by: 1fanwang <1fannnw@gmail.com> --- flytekit/core/type_engine.py | 6 +++--- tests/flytekit/unit/core/test_type_engine.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 555d571230..3ffb554ac9 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -2097,9 +2097,9 @@ async def async_to_literal( raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") if len(matches) > 1: - # More than one variant matched. Prefer the one whose type is exactly the value's type -- e.g. for - # Union[str, Color] and "red", both str and Color (which accepts a matching string) match, but str wins. - exact = [m for m in matches if type(python_val) is m[0]] + # A bare string matches both `str` and an enum whose values include it -- prefer `str` (the exact type). + # Every other multi-variant match stays ambiguous: structurally-identical variants are a read-back hazard. + exact = [m for m in matches if m[0] is str] if type(python_val) is str else [] if len(exact) != 1: raise TypeError( "Ambiguous choice of variant for union type.\n" diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0042820f15..de3f7dfd02 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2047,9 +2047,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: assert union_type_tags_unique(lt) ctx = FlyteContextManager.current_context() - # int and UnsignedInt both accept 3, but int is the value's exact type, so it wins without ambiguity. - lv = TypeEngine.to_literal(ctx, 3, pt, lt) - assert lv.scalar.union.stored_type.structure.tag == "int" + with pytest.raises(TypeError, match="Ambiguous choice of variant for union type"): + TypeEngine.to_literal(ctx, 3, pt, lt) del TypeEngine._REGISTRY[UnsignedInt]