Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 29 additions & 31 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,14 +1079,16 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(enum_type=_core_types.EnumType(values=values))

def to_literal(
self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType
self, ctx: FlyteContext, python_val: Union[enum.Enum, str], python_type: Type[T], expected: LiteralType
) -> Literal:
if type(python_val).__class__ != enum.EnumMeta:
raise TypeTransformerFailedError("Expected an enum")
if type(python_val.value) != str:
raise TypeTransformerFailedError("Only string-valued enums are supported")

return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore
# Accept a raw string matching one of the enum's values: assert_type already allows it (e.g. an enum
# default supplied as a string), so to_literal must too, else such a value type-checks but fails to serialize.
Comment on lines +1084 to +1085

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those details belong in the PR description rather than in a code comment. Let's simplify to:

Suggested change
# Accept a raw string matching one of the enum's values: assert_type already allows it (e.g. an enum
# default supplied as a string), so to_literal must too, else such a value type-checks but fails to serialize.
# Accept either an enum member or a raw string matching one of its values.

val = python_val.value if isinstance(python_val, enum.Enum) else python_val
if not isinstance(val, str):
raise TypeTransformerFailedError(f"Expected an enum or matching string, got {type(python_val)}")
if val not in [item.value for item in cast(Type[enum.Enum], python_type)]:
raise TypeTransformerFailedError(f"Value {python_val} is not in Enum {python_type}")
return Literal(scalar=Scalar(primitive=Primitive(string_value=val)))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if lv.scalar and lv.scalar.binary:
Expand Down Expand Up @@ -2074,44 +2076,40 @@ async def async_to_literal(
) -> typing.Union[Literal, asyncio.Future]:
python_type = get_underlying_type(python_type)

potential_types = []
found_res = False
is_ambiguous = False
res = None
res_type = None
t = None
for i in range(len(get_args(python_type))):
# (type, literal, literal_type) for every variant that accepts the value
matches: typing.List[typing.Tuple[type, Literal, LiteralType]] = []
for i, t in enumerate(get_args(python_type)):
try:
t = get_args(python_type)[i]
trans: TypeTransformer[T] = TypeEngine.get_transformer(t)
if isinstance(trans, AsyncTypeTransformer):
attempt = trans.async_to_literal(ctx, python_val, t, expected.union_type.variants[i])
res = await attempt
res = await trans.async_to_literal(ctx, python_val, t, expected.union_type.variants[i])
else:
res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i])
if found_res:
logger.debug(f"Current type {get_args(python_type)[i]} old res {res_type}")
is_ambiguous = True
res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
found_res = True
potential_types.append(t)
matches.append((t, res, res_type))
except Exception as e:
logger.debug(
f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}",
)
continue

if is_ambiguous:
raise TypeError(
f"Ambiguous choice of variant for union type.\n"
f"Potential types: {potential_types}\n"
"These types are structurally the same, because it's attributes have the same names and associated types."
)
if not matches:
raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}")

if found_res:
return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type)))
if len(matches) > 1:
# A bare string matches both `str` and an enum whose values include it -- prefer `str` (the exact type).
# Every other multi-variant match stays ambiguous: structurally-identical variants are a read-back hazard.
exact = [m for m in matches if m[0] is str] if type(python_val) is str else []
if len(exact) != 1:
raise TypeError(
"Ambiguous choice of variant for union type.\n"
f"Potential types: {[m[0] for m in matches]}\n"
"These types are structurally the same, because it's attributes have the same names and associated types."
)
matches = exact

raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}")
_, res, res_type = matches[0]
return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type)))

async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]
Expand Down
55 changes: 43 additions & 12 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,43 @@
import sys
import tempfile
import typing
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import List, Optional, Type, Dict
from typing import Dict, List, Optional, Type

import mock
import msgpack
import pytest
import typing_extensions
from concurrent.futures import ThreadPoolExecutor
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import errors_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from mashumaro.config import BaseConfig
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import Discriminator
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, get_args

from flytekit import dynamic, kwtypes, task, workflow
from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import flyte_tmp_dir
from flytekit.core.hash import HashMethod
from flytekit.core.type_engine import (
IntTransformer,
FloatTransformer,
BoolTransformer,
StrTransformer,
DataclassTransformer,
DictTransformer,
EnumTransformer,
FloatTransformer,
IntTransformer,
ListTransformer,
LiteralsResolver,
SimpleTransformer,
StrTransformer,
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
Expand All @@ -68,7 +67,7 @@
LiteralOffloadedMetadata,
Primitive,
Scalar,
Void, Binary,
Void,
)
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType
from flytekit.types.directory import TensorboardLogs
Expand All @@ -79,11 +78,11 @@
from flytekit.types.file import FileExt, JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop
from flytekit.types.iterator.iterator import IteratorTransformer
from flytekit.types.iterator.json_iterator import JSONIterator, JSONIteratorTransformer, JSON
from flytekit.types.iterator.json_iterator import JSON, JSONIterator, JSONIteratorTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine, PARQUET
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

T = typing.TypeVar("T")

Expand Down Expand Up @@ -1532,6 +1531,35 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_enum_to_literal_accepts_matching_string():
    """Check that to_literal takes the raw string "red" for the Color enum type.

    The resulting literal is expected to carry that same string as its
    primitive string value.
    """
    current_ctx = FlyteContextManager.current_context()
    enum_literal_type = TypeEngine.to_literal_type(Color)

    # Pass a plain str (not a Color member) while declaring the python type as Color.
    literal = TypeEngine.to_literal(current_ctx, "red", Color, enum_literal_type)
    assert literal.scalar.primitive.string_value == "red"


@pytest.mark.parametrize("python_val", ["purple", "Red", ""])
def test_enum_to_literal_rejects_non_matching_string(python_val):
    """Check that to_literal raises for strings the Color enum should not accept.

    The parametrized inputs are "purple", a differently-cased "Red", and the
    empty string; each one must produce a TypeTransformerFailedError.
    """
    current_ctx = FlyteContextManager.current_context()
    enum_literal_type = TypeEngine.to_literal_type(Color)

    # The conversion itself is the operation under test, so it is the only statement in the raises block.
    with pytest.raises(TypeTransformerFailedError):
        TypeEngine.to_literal(current_ctx, python_val, Color, enum_literal_type)


def test_enum_string_in_union_prefers_str():
    """Check that the bare string "red" converts under Union[Color, str] without error.

    The union literal is expected to hold the string "red" and to record the
    "str" variant as its stored type tag.
    """
    current_ctx = FlyteContextManager.current_context()
    union_type = typing.Union[Color, str]
    union_literal_type = TypeEngine.to_literal_type(union_type)

    converted = TypeEngine.to_literal(current_ctx, "red", union_type, union_literal_type)
    union_scalar = converted.scalar.union
    # The wrapped value must be the plain string ...
    assert union_scalar.value.scalar.primitive.string_value == "red"
    # ... and the selected variant must be tagged as str rather than the enum.
    assert union_scalar.stored_type.structure.tag == "str"


def test_multi_inheritance_enum_type():
tfm = TypeEngine.get_transformer(MultiInheritanceColor)
assert isinstance(tfm, EnumTransformer)
Expand Down Expand Up @@ -4112,8 +4140,9 @@ def test_asyncio_wait_empty_kwargs_regression():
"Set of Tasks/Futures is empty" ValueError.
"""
import asyncio
from flytekit.models import literals as _literal_models

from flytekit.core.type_engine import TypeEngine
from flytekit.models import literals as _literal_models

async def simulate_original_bug():
"""
Expand Down Expand Up @@ -4153,8 +4182,8 @@ def test_error_message_improvements_literal_map_to_kwargs():
Test that error messages in literal_map_to_kwargs use proper repr formatting
for better debugging experience.
"""
from flytekit.models import literals as _literal_models
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.models import literals as _literal_models

ctx = FlyteContextManager.current_context()

Expand Down Expand Up @@ -4192,6 +4221,7 @@ def test_error_message_improvements_union_transformer():
for better debugging when conversion fails.
"""
from typing import Union

from flytekit.models import literals as _literal_models

ctx = FlyteContextManager.current_context()
Expand Down Expand Up @@ -4229,6 +4259,7 @@ def test_debug_logging_union_transformer(caplog):
"""
import logging
from typing import Union

from flytekit.models import literals as _literal_models

# Set logging level to capture debug messages
Expand Down
Loading