
[torchlib] Reimplement as_strided without an ONNX loop #2928

Open
Copilot wants to merge 8 commits into main from copilot/torchlib-reimplement-as-strided

Conversation

Copilot AI commented Jun 3, 2026

Contributor

aten_as_strided was lowered to a private ONNX function (_aten_as_strided_onnx) that built gather indices via an unrolled loop of Expand/Range/SequenceInsert/ConcatFromSequence ops. This graph is hard for downstream passes to constant-fold.

Because aten_as_strided is already trace_only, the indices can be computed once with NumPy and emitted as a constant whenever size, stride, and storage_offset are all concrete at trace time. The SymInt inputs can also be dynamic (runtime values), so a second path builds the indices with ONNX ops while still avoiding any Loop/Scan.

Changes

  • ops/core.py: Replace the loop implementation with two paths that share the same index math. For each output position (i_0, ..., i_{n-1}), the storage index is storage_offset + Σ_d i_d · stride[d], and the result is Reshape(self, [-1]) followed by a Gather with those indices:
    • Static fast path (all of size/stride/storage_offset are concrete ints): fold the indices into a single constant index tensor.
    • Dynamic path (any SymInt is a runtime value): build the indices with ONNX ops (Range/Mul/Unsqueeze/Add). The per-dimension contributions are unrolled at trace time since the rank is always static, so no Loop/Scan is emitted. Runtime SymInt values are assumed to be INT64 and reshaped to scalars directly, and mixed static/dynamic dimensions are supported.
    • A default storage_offset=None is normalized to 0 so the dynamic path does not emit an invalid Reshape of a missing input.
  • ops/core.py: Remove the now-unused private _aten_as_strided_onnx function.
  • deduce_type_constraints_test.py: Drop _aten_as_strided_onnx from _SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN, since no loop/scan remains.
  • tests/function_libs/torch_lib/e2e_ops_tests.py: Add end-to-end coverage for both paths — static (multi-dimensional with non-zero storage_offset, single dimension, overlapping strides, scalar/empty size) and dynamic (size derived from the input shape, with and without storage_offset).
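
The shared index math in the first bullet can be written out explicitly. The worked values below (size = (2, 3), stride = (3, 1), storage_offset = 1) are chosen only for illustration and do not come from the PR's tests:

```latex
\[
\operatorname{idx}(i_0, \dots, i_{n-1})
  = \text{storage\_offset} + \sum_{d=0}^{n-1} i_d \cdot \text{stride}[d],
  \qquad 0 \le i_d < \text{size}[d]
\]
\[
\operatorname{idx}(1, 2) = 1 + 1 \cdot 3 + 2 \cdot 1 = 6
\]
```

Under these assumed values, output element [1, 2] reads element 6 of the flattened input.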

Implementation

# Static fast path
indices = np.array(storage_offset, dtype=np.int64)
for dim, (dim_size, dim_stride) in enumerate(zip(size, stride)):
    add_value = np.arange(dim_size, dtype=np.int64) * dim_stride
    broadcast_shape = [1] * len(size)
    broadcast_shape[dim] = dim_size
    indices = indices + add_value.reshape(broadcast_shape)
self_flatten = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.Gather(self_flatten, op.Constant(value=ir.tensor(indices)))

# Dynamic path (any SymInt is a runtime value; SymInts are assumed INT64)
indices = op.Reshape(storage_offset, empty_shape)
for dim in range(rank):
    dim_size = op.Reshape(size[dim], empty_shape)
    dim_stride = op.Reshape(stride[dim], empty_shape)
    add_value = op.Mul(op.Range(zero, dim_size, one), dim_stride)
    unsqueeze_axes = [axis for axis in range(rank) if axis != dim]
    if unsqueeze_axes:
        add_value = op.Unsqueeze(add_value, op.Constant(value_ints=unsqueeze_axes))
    indices = op.Add(indices, add_value)
result = op.Gather(self_flatten, indices)

The empty-size case naturally yields a 0-d index tensor, producing a scalar output. Both paths were checked against torch.as_strided for multi-dimensional, non-zero storage_offset, single-dimension, scalar/empty-size, and mixed static/dynamic inputs.
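
As a rough way to reproduce that check without torch, the static index math can be run in plain NumPy and compared against numpy.lib.stride_tricks.as_strided. This is a minimal sketch, not the PR's code: the helper names strided_gather_indices and as_strided_via_gather are made up for this example, np.take stands in for ONNX Gather, and the input is assumed to be contiguous.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def strided_gather_indices(size, stride, storage_offset=0):
    """Return an int64 array of shape `size` holding flat storage indices."""
    indices = np.array(storage_offset, dtype=np.int64)  # 0-d starting value
    for dim, (dim_size, dim_stride) in enumerate(zip(size, stride)):
        # Contribution of this dimension: i_d * stride[d] for i_d in [0, size[d]).
        add_value = np.arange(dim_size, dtype=np.int64) * dim_stride
        broadcast_shape = [1] * len(size)
        broadcast_shape[dim] = dim_size
        indices = indices + add_value.reshape(broadcast_shape)
    return indices


def as_strided_via_gather(x, size, stride, storage_offset=0):
    """Emulate Reshape(x, [-1]) + Gather for a contiguous input (assumption)."""
    return np.take(x.reshape(-1), strided_gather_indices(size, stride, storage_offset))


x = np.arange(12, dtype=np.int64)

# Overlapping windows with a non-zero offset: rows start 2 apart, length 3.
got = as_strided_via_gather(x, size=(4, 3), stride=(2, 1), storage_offset=1)
# NumPy's as_strided takes strides in bytes, hence the itemsize factor.
want = as_strided(x[1:], shape=(4, 3), strides=(2 * x.itemsize, 1 * x.itemsize))
assert np.array_equal(got, want)

# Empty size: the index tensor is 0-d, so the result is a scalar.
scalar = as_strided_via_gather(x, size=(), stride=(), storage_offset=5)
assert np.shape(scalar) == () and int(scalar) == 5
```

Because x is np.arange(12), each gathered value equals its own storage index, which makes the overlapping rows easy to inspect by eye.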

Copilot AI changed the title from "[WIP] Reimplement as_strided for efficiency" to "[torchlib] Reimplement as_strided without an ONNX loop" Jun 3, 2026
Copilot AI requested a review from justinchuby June 3, 2026 18:42
@justinchuby

Collaborator

@copilot the inputs marked as SymInt can be dynamic. Can you update the logic to account for this situation?

Copilot AI commented Jun 3, 2026

Contributor Author

@copilot the inputs marked as SymInt can be dynamic. Can you update the logic to account for this situation?

Added a dynamic path in 93d3247: when any of size/stride/storage_offset is a runtime value, the indices are built with ONNX ops (Range/Mul/Unsqueeze/Add) instead of a folded constant. The per-dimension contributions are still unrolled at trace time (no loop/scan) since the rank is static, and SymInt values are cast to INT64 for consistent arithmetic. The static fast path (all concrete) still folds to a single constant Gather. Verified both paths against torch.as_strided for multi-dim, non-zero storage_offset, scalar/empty-size, mixed static/dynamic, and int32/int64 SymInt inputs.
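
The shape bookkeeping of the dynamic path can be mimicked in NumPy to see why no loop op is needed. This is only an emulation under stated assumptions: np.arange stands in for Range, np.expand_dims with a tuple of axes stands in for Unsqueeze, and the function name dynamic_path_indices is hypothetical. The Python for-loop runs at trace time over the static rank, so it unrolls into a fixed sequence of ops.

```python
import numpy as np


def dynamic_path_indices(size, stride, storage_offset):
    """NumPy stand-in for the unrolled Range/Mul/Unsqueeze/Add sequence."""
    rank = len(size)  # rank is static even when the values are not
    indices = np.array(storage_offset, dtype=np.int64)  # scalar running index
    for dim in range(rank):
        # Range(0, size[dim], 1) * stride[dim]
        add_value = np.arange(0, int(size[dim]), 1, dtype=np.int64) * np.int64(stride[dim])
        # Unsqueeze every axis except `dim`, so the values lie along `dim`.
        unsqueeze_axes = tuple(axis for axis in range(rank) if axis != dim)
        if unsqueeze_axes:
            add_value = np.expand_dims(add_value, axis=unsqueeze_axes)
        indices = indices + add_value  # Add, relying on broadcasting
    return indices


# A size derived from an input's shape, similar in spirit to the dynamic tests.
x = np.arange(24, dtype=np.int64).reshape(4, 6)
idx = dynamic_path_indices(size=(x.shape[0], 2), stride=(6, 1), storage_offset=0)
result = np.take(x.reshape(-1), idx)  # first two columns of every row
assert np.array_equal(result, x[:, :2])
```

For a rank-1 size the list of unsqueeze axes is empty, which is why the Unsqueeze is guarded by an if in both the excerpt and this sketch.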

Comment thread onnxscript/function_libs/torch_lib/ops/core.py Outdated
@justinchuby

Collaborator

@copilot can you add good coverage on the new logic in tests/function_libs/torch_lib/e2e_ops_tests.py?

@justinchuby

Collaborator

@copilot also add more detailed comment to explain what each line of the implementation is doing in relation to the torch op, as it is a more complicated op, for readability

Copilot AI commented Jun 3, 2026

Contributor Author

@copilot can you add good coverage on the new logic in tests/function_libs/torch_lib/e2e_ops_tests.py?

Added e2e coverage in tests/function_libs/torch_lib/e2e_ops_tests.py exercising both paths: static (multi-dim with non-zero storage_offset, single dim, overlapping strides, scalar/empty size) and dynamic (size derived from the input shape, with and without storage_offset). Writing the dynamic tests surfaced a bug where a default storage_offset=None reached the dynamic path and produced an invalid Reshape; fixed by normalizing None to 0. All tests pass against torch.as_strided.
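
The None normalization described above might look roughly like the following. The helper name normalize_storage_offset is hypothetical, and the int type hint is a simplification, since in the real function the offset may also be a runtime value. An explicit `is None` check is used here rather than a truthiness test, on the assumption that a runtime value should never be evaluated as a boolean.

```python
from typing import Optional


def normalize_storage_offset(storage_offset: Optional[int]) -> int:
    """Treat a missing offset as 0 so later ops always receive a real input."""
    return 0 if storage_offset is None else storage_offset


assert normalize_storage_offset(None) == 0
assert normalize_storage_offset(3) == 3
```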

Copilot AI commented Jun 3, 2026

Contributor Author

@copilot also add more detailed comment to explain what each line of the implementation is doing in relation to the torch op, as it is a more complicated op, for readability

Added detailed line-by-line comments to aten_as_strided, explaining the storage-index math in relation to torch.as_strided and annotating each step of both the static and dynamic paths.

empty_shape = op.Constant(value=ir.tensor(np.array([], dtype=np.int64)))
# Start the running index from storage_offset, cast to an INT64 scalar so all
# the arithmetic below has a consistent dtype regardless of how the SymInt
# runtime values are typed (e.g. int32 SymInts).

Collaborator


@copilot I think we can assume all symints are int64.

Contributor Author


Removed the INT64 Cast ops from the dynamic path and now reshape the SymInt values directly, assuming they are already INT64.

@justinchuby justinchuby requested a review from Copilot June 3, 2026 20:16
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Jun 3, 2026

Copilot AI left a comment

Contributor


Pull request overview

This PR refactors the torchlib lowering for aten::as_strided to avoid generating an ONNX Loop/Scan when constructing gather indices, enabling simpler downstream constant-folding. It introduces a static fast path that precomputes the index tensor with NumPy when all size/stride/storage_offset values are known at trace time, and a dynamic path that builds the same index math using ONNX ops without loops.

Changes:

  • Reimplemented aten_as_strided in ops/core.py as a Reshape([-1]) + Gather with (1) a NumPy-constant index fast path and (2) an ONNX-op dynamic index path (no Loop/Scan).
  • Removed the now-unused private _aten_as_strided_onnx lowering and unblocked type-constraint deduction by removing it from the “skip loop/scan” list.
  • Added new E2E tests covering several as_strided scenarios (static and dynamic shapes/offsets).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| tests/function_libs/torch_lib/e2e_ops_tests.py | Adds E2E export coverage for torch.as_strided across static and dynamic cases. |
| onnxscript/function_libs/torch_lib/ops/core.py | Replaces loop-based index construction with static NumPy-constant and dynamic ONNX-op paths. |
| onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py | Removes _aten_as_strided_onnx from the loop/scan skip list since it no longer exists. |

Comment thread onnxscript/function_libs/torch_lib/ops/core.py
@justinchuby justinchuby marked this pull request as ready for review June 9, 2026 18:49
@codecov

codecov Bot commented Jun 9, 2026


Codecov Report

❌ Patch coverage is 41.66667% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.66%. Comparing base (33d1445) to head (40f2bb1).
⚠️ Report is 1 commit behind head on main.
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| onnxscript/function_libs/torch_lib/ops/core.py | 41.66% | 13 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2928      +/-   ##
==========================================
+ Coverage   72.63%   72.66%   +0.02%     
==========================================
  Files         259      259              
  Lines       31665    31662       -3     
  Branches     2981     2984       +3     
==========================================
+ Hits        22999    23006       +7     
+ Misses       7656     7645      -11     
- Partials     1010     1011       +1     


@titaiwangms

Contributor

Multi-reviewer summary (readability, code, critical, deep-semantic, integration)

Genuine improvement — replacing the un-foldable Loop with a single flatten + Gather is the right call. The index math (storage_offset + Σ_d i_d·stride[d]) was verified correct against torch.as_strided, including empty-size→scalar and overlapping-stride cases, on both the static NumPy path and the unrolled dynamic ONNX path. The private _aten_as_strided_onnx is cleanly removed and the deduce_type_constraints skip-list updated.

Major

  • Dynamic path assumes SymInt scalars are already INT64. zero/one are INT64 constants, but size[dim], stride[dim], and storage_offset are only Reshaped, not cast. An INT32 scalar from some export path would cause Range/Mul/Add type mismatches. Suggest an explicit Cast(..., to=INT64.dtype) on the dynamic size/stride/offset scalars.
  • Non-contiguous storage semantics. Reshape(self, [-1]) yields the logical row-major order, not torch's underlying storage order, so a non-contiguous self (e.g. a transposed view) would diverge. This matches the prior implementation's assumption and is normally guaranteed by dynamo decomposition, so it is a documented limitation rather than a regression — worth a one-line comment noting the contiguity assumption.
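
The contiguity point can be illustrated with NumPy, whose reshape on a transposed view also follows the logical row-major order. This is a sketch of the general effect only, not of the exporter's behavior; the array values are made up for the example.

```python
import numpy as np

base = np.arange(6, dtype=np.int64).reshape(2, 3)  # storage order: 0 1 2 3 4 5
view = base.T  # transposed view: same storage, non-contiguous layout

logical_flat = view.reshape(-1)  # row-major over the logical layout
storage_flat = base.reshape(-1)  # the underlying storage order

indices = np.array([0, 1, 2], dtype=np.int64)
print(logical_flat[indices])  # [0 3 1]
print(storage_flat[indices])  # [0 1 2]
```

The same gather indices pick different elements depending on which flattening is used, which is why the lowering relies on self being contiguous.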

Minor

  • Static path can materialize huge constants. For large/unfold-like views the static path emits an int64 index tensor of shape size, which can bloat the model / blow up exporter memory. Consider a product-size threshold that falls back to the dynamic path.
  • Sibling consistency (out of scope for this PR). aten_as_strided_copy, aten_as_strided_scatter, aten_empty_strided, and aten_new_empty_strided still use the old INT64 / Sequence[int] hints rather than the new Sequence[INT64]. Possible follow-up.
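
One possible shape for the suggested product-size threshold is sketched below. The function name, the constant MAX_STATIC_INDEX_ELEMENTS, and its value are all hypothetical; the review does not propose a specific cutoff.

```python
import math

# Hypothetical cutoff, chosen only for illustration.
MAX_STATIC_INDEX_ELEMENTS = 1 << 20


def use_static_path(size, stride, storage_offset, limit=MAX_STATIC_INDEX_ELEMENTS):
    """Return True when all values are concrete ints and the index tensor is small."""
    values = [*size, *stride, storage_offset]
    if not all(isinstance(v, int) for v in values):
        return False  # some SymInt is a runtime value: use the dynamic path
    # The folded int64 constant holds prod(size) elements, 8 bytes each.
    return math.prod(size) <= limit


assert use_static_path((4, 3), (2, 1), 1)
assert not use_static_path((4096, 4096), (1, 1), 0)  # 16,777,216 elements
assert not use_static_path((object(), 3), (3, 1), 0)  # stand-in for a runtime value
```

With this kind of guard, an oversized static view would take the ONNX-op path and the model would carry only the small size and stride inputs instead of a large constant.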

Readability

  • empty_shape reads as "shape of an empty tensor" but is actually the reshape-to-scalar target; rename to scalar_shape.
  • The identical trailing comment appears in both branches — could be hoisted once before the if/else.

The architectural comment block explaining the storage-offset formula is excellent documentation.

@justinchuby

Collaborator

Thanks. SymInts are always int64. Others I will fix


Development

Successfully merging this pull request may close these issues.

[torchlib] Reimplement as_strided

4 participants