Skip to content

feat[next]: Tracer support part 1: tree_map#2586

Open
SF-N wants to merge 29 commits into
GridTools:mainfrom
SF-N:tracer_support_tree_map
Open

feat[next]: Tracer support part 1: tree_map#2586
SF-N wants to merge 29 commits into
GridTools:mainfrom
SF-N:tracer_support_tree_map

Conversation

@SF-N

@SF-N SF-N commented Apr 27, 2026

Copy link
Copy Markdown
Contributor

As a first step towards tracer support (enabling vector operations), this PR introduces a new tree_map operator for mapping functions over tuples (including nesting). tree_map is unrolled to make_tuple calls in UnrollTreeMap.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces an iterator-level tree_map builtin as an IR operator to support mapping functions over (nested) tuples, as a first step towards tracer support and future vector operations. It includes type synthesis for tree_map and a transform (UnrollTreeMap) that lowers tree_map(f)(...) into explicit make_tuple / tuple_get IR.

Changes:

  • Add tree_map builtin plumbing (builtin dispatch + IR maker helper) and update tuple-where lowering to emit tree_map.
  • Add tree_map type synthesizer and a new UnrollTreeMap transform, wired into the iterator pass pipeline.
  • Add unit tests for the _unroll helper and adjust existing frontend lowering expectations.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py Adds unit tests for _unroll tuple expansion behavior.
tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py Updates tuple where reference IR to use tree_map.
src/gt4py/next/iterator/type_system/type_synthesizer.py Registers and implements type synthesis for tree_map.
src/gt4py/next/iterator/transforms/unroll_tree_map.py New transform to unroll tree_map into tuple primitives.
src/gt4py/next/iterator/transforms/pass_manager.py Runs UnrollTreeMap and tuple-collapsing before domain inference.
src/gt4py/next/iterator/ir_utils/ir_makers.py Adds im.tree_map(...) helper for constructing IR.
src/gt4py/next/iterator/builtins.py Adds tree_map to builtin dispatch and builtin name set.
src/gt4py/next/ffront/foast_to_gtir.py Lowers tuple where via tree_map instead of explicit tuple construction.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/gt4py/next/iterator/transforms/unroll_tree_map.py Outdated
Comment thread src/gt4py/next/iterator/transforms/pass_manager.py Outdated
Comment thread src/gt4py/next/iterator/transforms/pass_manager.py Outdated
Comment thread src/gt4py/next/iterator/type_system/type_synthesizer.py Outdated
Comment thread src/gt4py/next/iterator/type_system/type_synthesizer.py Outdated
Comment on lines +182 to +199
# After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that
# domain inference does not encounter `as_fieldop` nodes inside dead tuple elements
# (which would receive NEVER domain). Do multiple iterations for nested `let`s.
for _ in range(10):
collapsed = ir
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
if ir == collapsed:
break
else:
raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this test_reduction_expression_with_where_and_tuples fails with ValueError: 'target_domain' cannot be 'NEVER' unless "allow_uninferred=True".

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: probably this is also the test case where the loop is required. I'll take a look if another configuration of the pass helps to avoid the loop.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/gt4py/next/iterator/transforms/unroll_tree_map.py Outdated
Comment thread src/gt4py/next/iterator/transforms/pass_manager.py Outdated
Comment thread src/gt4py/next/iterator/type_system/type_synthesizer.py Outdated
Comment thread src/gt4py/next/iterator/transforms/unroll_tree_map.py Outdated
f"Call to object of type '{type(node.func.type).__name__}' not understood."
)

def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

@SF-N SF-N Apr 28, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think _visit_astype cannot use tree_map because it needs type-dependent lowering per leaf: fields use _map(cast) while scalars use cast directly. Let's discuss if you have something else in mind.


return im.let(cond_symref_name, cond_)(result)

def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I see, _visit_concat_where already has its own expand_tuple_args pass for handling nested tuples, and each branch can have a different domain. Attempting to wrap it in tree_map caused type inference failures. Let's discuss if you have something else in mind.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other reasons: A single concat_where is better to digest for optimizations. The lowering would actually be complicated if we would emit tree_map in foast_to_gtir.

Comment thread src/gt4py/next/iterator/transforms/unroll_tree_map.py Outdated
Comment thread src/gt4py/next/iterator/type_system/type_synthesizer.py Outdated
Comment thread src/gt4py/next/iterator/type_system/type_synthesizer.py Outdated
@SF-N SF-N requested a review from tehrengruber April 29, 2026 13:52

@havogt havogt left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two findings from reviewing this PR and verifying locally against the branch (simple_mesh, all backends). Details inline.

(node.args[0].type, *arg_types),
result = im.tree_map_tuple(
im.lambda_("__a", "__b")(
im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b"))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple-where over a per-neighbour (local) condition produces invalid IR.

Hardcoding the per-leaf op to op_as_fieldop("if_") drops the local-field handling the non-tuple path keeps: _lower_and_map("if_", …)_map wraps the op in map_list(if_) and promotes the condition with make_const_list when the leaves are local fields. The single, uniform tree_map_tuple leaf can't do that.

This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the where condition is itself a local (per-neighbour) field — the mask only has to be a FieldType, and a local field qualifies:

@gtx.field_operator
def tup(a: EdgeF, b: EdgeF, c: EdgeF, d: EdgeF) -> tuple[VertF, VertF]:
    cond = a(V2E) > c(V2E)                               # per-neighbour (local) bool field
    t = where(cond, (a(V2E), b(V2E)), (c(V2E), d(V2E)))
    return (neighbor_sum(t[0], axis=V2EDim), neighbor_sum(t[1], axis=V2EDim))

The non-tuple equivalent (neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)) works on all three backends; the tuple version fails on all three with AssertionError in the if_ type synthesizer (isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL), because after UnrollTupleMaps the leaf is ⇑(λ(c,x,y) → if ·c then ·x else ·y) applied to a list-typed predicate — no map_list.

Broadcast (non-local) conditions happen to pass, since wholesale if_(scalar, list, list) equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.

This is the same per-leaf type-dependent reason given for not migrating _visit_astype. It also isn't fixable by "always emit map_list": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keep where on process_elements, or wrap if_ in map_list post-unroll when the predicate is a list.

return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value]

if recursive:
return utils.tree_map( # type: ignore[return-value]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mismatched tuple structure raises a bare AssertionError instead of a clear TypeError.

The check above guarantees all args are TupleType, but not that they share arity/nesting. When they differ, utils.tree_map trips its internal assert ... len(args[0]) == len(arg) (or the all-collection assert during recursion) and raises an AssertionError with no message:

im.tree_map_tuple(im.ref("plus"))(
    im.ref("t1", ts.TupleType(types=[int_t, int_t, int_t])),
    im.ref("t2", ts.TupleType(types=[int_t, int_t])),
)   # -> AssertionError (empty message); also for mismatched nesting

Not reachable from the frontend today — where pre-validates branch structure with a DSLError — but it'll bite the upcoming map_tuple/tracer producers and anyone building IR directly. A TypeError here (and in the matching spot in UnrollTupleMaps) would be friendlier.

@tehrengruber tehrengruber left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following lowering strategy does not work right now:
(2*el for el in (a(V2E), b)) -> map_tuple(lambda el: op_as_fieldop(*)(2, el))({a_field_with_local_dim, b_field_of_scalars}) since the el symref node can only have one type and since we can not distinguish (i.e. no isinstance on GTIR level) between them.
However since we know the length we can use the process_elements approach like for where. In case we don't know the length the element type is always the same and we can use map_tuple in the lowering and then fold it on GTIR level.

(2*el for el in var_length_tuple_scalar) -> map_tuple(lambda el: 2*el)(var_length_tuple)
(2*el for el in var_length_tuple_of_field) -> map_tuple(lambda el: as_fieldop(lambda el_at_grid_point: 2*el_at_grid_point)(el))(var_length_tuple)
(2*el for el in var_length_tuple_field_with_local_dim) -> map_tuple(lambda el: map_list(lambda in_el: 2*in_el)(el))(tuple_of_it_of_list)

Conclusion:
We continue / merge this PR (in the previous form and after review). In #2487 we can then use map_tuple in the lowering for the variable length cases and only need process elements for the fixed length case. We also need tests to cover all cases (i.e. testing for where, concat_where, tuple comprehension of elements with mixed type like (a(V2E), b)).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants