Skip to content

Optimize function that loads pointers on GPU#3001

Open
timmoon10 wants to merge 11 commits into
NVIDIA:mainfrom
timmoon10:tmoon/optimize-get_device_pointer_for_data_and_scales
Open

Optimize function that loads pointers on GPU#3001
timmoon10 wants to merge 11 commits into
NVIDIA:mainfrom
timmoon10:tmoon/optimize-get_device_pointer_for_data_and_scales

Conversation

@timmoon10
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 commented May 16, 2026

Description

tex.get_device_pointer_for_data_and_scales has two problems:

  1. It has significant CPU overhead (see [PyTorch] Reduce CPU overhead in grouped MLP block #2897). In a representative benchmark on a GB200, it takes 85 us per call.
  2. The meaning is extremely unintuitive. The most natural interpretation is that it takes a FP8/MXFP8/NVFP4 tensor and returns pointers as two int s. But actually it takes the buffers from multiple MXFP8/NVFP4 tensors (all assumed to have the same shape), swizzles the scaling factors, and transfers the pointers to a GPU array in a CUDA Graph-friendly way.

This PR makes several optimizations to reduce CPU overhead, mostly to reduce unnecessary heap allocations. I've also attempted to make the functionality more general and logical:

  • nvte_load_value_on_device: A general function for copying a small amount of data to GPU in a CUDA Graph-friendly way. Unlike nvte_convert_pointers_to_tensor, it makes no assumptions that the data is a list of pointers.
  • tex.load_data_ptrs_on_device: Takes a list of tensors and puts their data pointers into a GPU buffer.
  • tex.transform_and_load_data_ptrs_on_device: Performs a user-provided transform on a list of tensors and puts the resulting data pointers into a GPU buffer. Currently it only supports scale swizzles on uniformly shaped tensors, but the transform names help make the contracts explicit.

With these changes, per-call CPU runtime has dropped from 85 us to 44 us. (Tentative, need to rerun benchmark).

This is progress toward #2897.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Performance optimization

Changes

  • Avoid heap allocation when converting NVTEBasicTensor to transformer_engine::SimpleTensor
  • Avoid heap allocation in transformer_engine::Tensor shape functions
  • Add transformer_engine::Tensor::flat_2d_dims to compute first and last dims simultaneously
  • Generalize and rename nvte_load_value_on_device
  • Refactor and rename tex.load_data_ptrs_on_device and tex.transform_and_load_data_ptrs_on_device

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 8 commits May 15, 2026 01:35
Avoid constructing temporary std::vector when converting NVTEBasicTensor to SimpleTensor. Avoid string operations in multi-tensor swizzle. Avoid temporary std::vector when checking scale tensors.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Tensor::shape() returns a std::vector<size_t> by value, allocating
on the heap. flat_first_dim and flat_last_dim only need to walk
the dims, so the allocation was pure overhead in hot paths.

Introduce Tensor::compute_shape() returning an NVTEShape (fixed
inline buffer, no heap) as the single source of truth for the
format-dependent shape logic. shape() is now a thin std::vector
wrapper around it for callers that want a vector; flat_first_dim
and flat_last_dim call compute_shape() directly.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
flat_first_dim() and flat_last_dim() each called compute_shape()
independently. flat_2d_dims() computes both in a single pass; the
scalar helpers now delegate to it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace all paired flat_first_dim() + flat_last_dim() calls on the
same tensor with a single flat_2d_dims() call. Saves one compute_shape()
per tensor in CheckScaleTensorShape, the multi-tensor swizzle loop, and
various cast/GEMM dispatch paths.

Also adds reserve() to the local vectors in
nvte_multi_tensor_swizzle_scaling_factors to avoid reallocation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace the inline swizzle implementation with a call to
multi_tensor_swizzle_scales_for_gemm, which has identical logic
(16B-aligned contiguous output buffer, TensorWrapper construction,
nvte_multi_tensor_swizzle_scaling_factors kernel). Swizzled pointers
are read back from the updated TensorWrappers after the call.

Add reserve() to vectors in multi_tensor_swizzle_scales_for_gemm_impl
now that this function is on the hot path for get_device_pointer_for_data_and_scales.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 16, 2026

Greptile Summary

This PR replaces tex.get_device_pointer_for_data_and_scales and tex.convert_host_pointers_to_tensor with a cleaner, lower-overhead API (nvte_load_value_on_device, load_data_ptrs_on_device, transform_and_load_data_ptrs_on_device), and refactors Tensor::shape() to avoid heap allocations by returning NVTEShape internally.

  • Core kernel redesign: The new nvte_load_value_on_device passes host data through kernel arguments (up to 2 KB/launch) and writes via 32-bit vectorized stores, eliminating per-call malloc/free overhead that contributed to the 85 µs latency.
  • Shape helpers: compute_shape() returns a stack-allocated NVTEShape; flat_2d_dims() computes both dimensions in one pass; flat_first_dim() and flat_last_dim() are now thin wrappers, preserving backwards compatibility throughout the call tree.
  • Python API split: The Python callers in forward_grouped_mlp.py and backward_grouped_mlp.py now call load_data_ptrs_on_device for data pointers and transform_and_load_data_ptrs_on_device with explicit transform-type strings for swizzled scale pointers, with keepalive buffers properly captured.

Confidence Score: 4/5

Safe to merge; the performance-critical paths are correct and the new API is a clean improvement over the old one.

The refactoring is well-structured and the core kernel logic is sound. The one gap worth tracking is that transform_and_load_data_ptrs_on_device computes allocation stride from tensors[0] without asserting that all other tensors share the same shape — the uniform contract is documented in a comment but not enforced. All current callers pass homogeneous weight tensors so this does not affect present behavior, but the missing guard leaves a latent trap for future callers.

transformer_engine/pytorch/csrc/extensions/utils.cpp — the uniform-shape assumption in transform_and_load_data_ptrs_on_device is undocumented at the API boundary and unvalidated at runtime.

Important Files Changed

Filename Overview
transformer_engine/common/util/utils.cu Replaces the old pointer-copy kernel with a general nvte_load_value_on_device that chunks arbitrary host data (up to 2 KB at a time) into kernel args and writes via uint32_t vectors; deprecated wrapper delegates correctly. Logic is sound.
transformer_engine/common/common.h Refactors Tensor::shape() into compute_shape() (returns NVTEShape to avoid heap allocation) and adds flat_2d_dims() returning std::array<size_t,2>; wrapper methods flat_first_dim()/flat_last_dim() preserved. Behavioral parity maintained.
transformer_engine/pytorch/csrc/extensions/utils.cpp New load_data_ptrs_on_device and transform_and_load_data_ptrs_on_device replace old monolithic helpers; uniform-stride allocation in the swizzle path assumes all input tensors share the same scale shape but does not validate this invariant explicitly.
transformer_engine/common/swizzle/swizzle.cu Minor cleanups: get_max_dynamic_smem now uses a lambda+static idiom, several reserve() calls added, flat_2d_dims() replaces paired flat_first_dim/flat_last_dim calls. Error messages lose per-tensor index but functionality is unchanged.
transformer_engine/common/transformer_engine.cpp Scale-shape validation rewritten to use std::array + flat_2d_dims() for MXFP8 and NVFP4 paths; both rowwise and columnwise error messages now correctly print the appropriate shape field.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Updated call sites: get_device_pointer_for_data_and_scales split into load_data_ptrs_on_device + transform_and_load_data_ptrs_on_device with explicit transform-type strings; keepalive buffers properly captured.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Same call-site migration as forward path; keepalive buffer variables correctly prevent premature deallocation during async GPU execution.

Reviews (3): Last reviewed commit: "Formatter and review suggestions from @g..." | Re-trigger Greptile

Comment thread transformer_engine/common/transformer_engine.cpp
Comment thread transformer_engine/common/util/utils.cu
dtype(static_cast<DType>(tensor.dtype)) {}

SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
SimpleTensor &operator=(const NVTEBasicTensor &tensor) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this assignment operator, assigning from a NVTEBasicTensor triggers a heap allocator in the NVTEBasicTensor constructor. We do this assignment frequently within nvte_set_tensor_param_v2.

Comment thread transformer_engine/common/util/utils.cu Outdated
NVTE_CHECK(data_tensors[0].is_cuda(), "data_tensors must be on CUDA.");
const auto device = data_tensors[0].device();
auto stream = at::cuda::getCurrentCUDAStream();
std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not committed to this name. I based it on std::transform. I suppose "map" would be more Python-focused, but that sounds worse.

Comment thread transformer_engine/common/util/utils.cu
Comment thread transformer_engine/common/transformer_engine.cpp Outdated
timmoon10 and others added 3 commits May 16, 2026 11:49
- Use size_t in kernel tail loop (was int64_t)
- Zero-initialize Payload before memcpy (Payload{})
- Rename Payload members to kMaxBytes/kVectorSize/kMaxVectors (linter)
- Consistent at::empty shape pattern: {static_cast<int64_t>(N)}
- Drop intermediate swizzled_scales_bytes variable
- Add comment explaining uniform-stride assumption in
  transform_and_load_data_ptrs_on_device
- Rename sfb_buffer -> _sfb_buffer (keepalive, not directly used)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 force-pushed the tmoon/optimize-get_device_pointer_for_data_and_scales branch from 7946e5d to 48cc585 Compare May 16, 2026 11:53
@timmoon10
Copy link
Copy Markdown
Collaborator Author

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant