feat[next]: Tracer support part 1: tree_map#2586
Conversation
…porting nesting (extracted from GridTools#2487)
There was a problem hiding this comment.
Pull request overview
This PR introduces an iterator-level tree_map builtin as an IR operator to support mapping functions over (nested) tuples, as a first step towards tracer support and future vector operations. It includes type synthesis for tree_map and a transform (UnrollTreeMap) that lowers tree_map(f)(...) into explicit make_tuple / tuple_get IR.
Changes:
- Add
tree_mapbuiltin plumbing (builtin dispatch + IR maker helper) and update tuple-wherelowering to emittree_map. - Add
tree_maptype synthesizer and a newUnrollTreeMaptransform, wired into the iterator pass pipeline. - Add unit tests for the
_unrollhelper and adjust existing frontend lowering expectations.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py | Adds unit tests for _unroll tuple expansion behavior. |
| tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | Updates tuple where reference IR to use tree_map. |
| src/gt4py/next/iterator/type_system/type_synthesizer.py | Registers and implements type synthesis for tree_map. |
| src/gt4py/next/iterator/transforms/unroll_tree_map.py | New transform to unroll tree_map into tuple primitives. |
| src/gt4py/next/iterator/transforms/pass_manager.py | Runs UnrollTreeMap and tuple-collapsing before domain inference. |
| src/gt4py/next/iterator/ir_utils/ir_makers.py | Adds im.tree_map(...) helper for constructing IR. |
| src/gt4py/next/iterator/builtins.py | Adds tree_map to builtin dispatch and builtin name set. |
| src/gt4py/next/ffront/foast_to_gtir.py | Lowers tuple where via tree_map instead of explicit tuple construction. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that | ||
| # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements | ||
| # (which would receive NEVER domain). Do multiple iterations for nested `let`s. | ||
| for _ in range(10): | ||
| collapsed = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == collapsed: | ||
| break | ||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") |
There was a problem hiding this comment.
Without this test_reduction_expression_with_where_and_tuples fails with ValueError: 'target_domain' cannot be 'NEVER' unless "allow_uninferred=True".
There was a problem hiding this comment.
Note: probably this is also the test case where the loop is required. I'll take a look if another configuration of the pass helps to avoid the loop.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| f"Call to object of type '{type(node.func.type).__name__}' not understood." | ||
| ) | ||
|
|
||
| def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: |
There was a problem hiding this comment.
I think _visit_astype cannot use tree_map because it needs type-dependent lowering per leaf: fields use _map(cast) while scalars use cast directly. Let's discuss if you have something else in mind.
|
|
||
| return im.let(cond_symref_name, cond_)(result) | ||
|
|
||
| def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: |
There was a problem hiding this comment.
As far as I see, _visit_concat_where already has its own expand_tuple_args pass for handling nested tuples, and each branch can have a different domain. Attempting to wrap it in tree_map caused type inference failures. Let's discuss if you have something else in mind.
There was a problem hiding this comment.
Other reasons: A single concat_where is better to digest for optimizations. The lowering would actually be complicated if we would emit tree_map in foast_to_gtir.
havogt
left a comment
There was a problem hiding this comment.
Two findings from reviewing this PR and verifying locally against the branch (simple_mesh, all backends). Details inline.
| (node.args[0].type, *arg_types), | ||
| result = im.tree_map_tuple( | ||
| im.lambda_("__a", "__b")( | ||
| im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) |
There was a problem hiding this comment.
Tuple-where over a per-neighbour (local) condition produces invalid IR.
Hardcoding the per-leaf op to op_as_fieldop("if_") drops the local-field handling the non-tuple path keeps: _lower_and_map("if_", …) → _map wraps the op in map_list(if_) and promotes the condition with make_const_list when the leaves are local fields. The single, uniform tree_map_tuple leaf can't do that.
This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the where condition is itself a local (per-neighbour) field — the mask only has to be a FieldType, and a local field qualifies:
@gtx.field_operator
def tup(a: EdgeF, b: EdgeF, c: EdgeF, d: EdgeF) -> tuple[VertF, VertF]:
cond = a(V2E) > c(V2E) # per-neighbour (local) bool field
t = where(cond, (a(V2E), b(V2E)), (c(V2E), d(V2E)))
return (neighbor_sum(t[0], axis=V2EDim), neighbor_sum(t[1], axis=V2EDim))The non-tuple equivalent (neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)) works on all three backends; the tuple version fails on all three with AssertionError in the if_ type synthesizer (isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL), because after UnrollTupleMaps the leaf is ⇑(λ(c,x,y) → if ·c then ·x else ·y) applied to a list-typed predicate — no map_list.
Broadcast (non-local) conditions happen to pass, since wholesale if_(scalar, list, list) equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.
This is the same per-leaf type-dependent reason given for not migrating _visit_astype. It also isn't fixable by "always emit map_list": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keep where on process_elements, or wrap if_ in map_list post-unroll when the predicate is a list.
| return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] | ||
|
|
||
| if recursive: | ||
| return utils.tree_map( # type: ignore[return-value] |
There was a problem hiding this comment.
Mismatched tuple structure raises a bare AssertionError instead of a clear TypeError.
The check above guarantees all args are TupleType, but not that they share arity/nesting. When they differ, utils.tree_map trips its internal assert ... len(args[0]) == len(arg) (or the all-collection assert during recursion) and raises an AssertionError with no message:
im.tree_map_tuple(im.ref("plus"))(
im.ref("t1", ts.TupleType(types=[int_t, int_t, int_t])),
im.ref("t2", ts.TupleType(types=[int_t, int_t])),
) # -> AssertionError (empty message); also for mismatched nestingNot reachable from the frontend today — where pre-validates branch structure with a DSLError — but it'll bite the upcoming map_tuple/tracer producers and anyone building IR directly. A TypeError here (and in the matching spot in UnrollTupleMaps) would be friendlier.
There was a problem hiding this comment.
The following lowering strategy does not work right now:
(2*el for el in (a(V2E), b)) -> map_tuple(lambda el: op_as_fieldop(*)(2, el))({a_field_with_local_dim, b_field_of_scalars}) since the el symref node can only have one type and since we can not distinguish (i.e. no isinstance on GTIR level) between them.
However since we know the length we can use the process_elements approach like for where. In case we don't know the length the element type is always the same and we can use map_tuple in the lowering and then fold it on GTIR level.
(2*el for el in var_length_tuple_scalar) -> map_tuple(lambda el: 2*el)(var_length_tuple)
(2*el for el in var_length_tuple_of_field) -> map_tuple(lambda el: as_fieldop(lambda el_at_grid_point: 2*el_at_grid_point)(el))(var_length_tuple)
(2*el for el in var_length_tuple_field_with_local_dim) -> map_tuple(lambda el: map_list(lambda in_el: 2*in_el)(el))(tuple_of_it_of_list)
Conclusion:
We continue / merge this PR (in the previous form and after review). In #2487 we can then use map_tuple in the lowering for the variable length cases and only need process elements for the fixed length case. We also need tests to cover all cases (i.e. testing for where, concat_where, tuple comprehension of elements with mixed type like (a(V2E), b)).
As a first step towards tracer support (enabling vector operations), this PR introduces a new
tree_mapoperator for mapping functions over tuples (including nesting).tree_mapis unrolled tomake_tuplecalls inUnrollTreeMap.