Optimize function that loads pointers on GPU#3001
Conversation
Avoid constructing temporary std::vector when converting NVTEBasicTensor to SimpleTensor. Avoid string operations in multi-tensor swizzle. Avoid temporary std::vector when checking scale tensors. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Tensor::shape() returns a std::vector<size_t> by value, allocating on the heap. flat_first_dim and flat_last_dim only need to walk the dims, so the allocation was pure overhead in hot paths. Introduce Tensor::compute_shape() returning an NVTEShape (fixed inline buffer, no heap) as the single source of truth for the format-dependent shape logic. shape() is now a thin std::vector wrapper around it for callers that want a vector; flat_first_dim and flat_last_dim call compute_shape() directly. Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
flat_first_dim() and flat_last_dim() each called compute_shape() independently. flat_2d_dims() computes both in a single pass; the scalar helpers now delegate to it. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace all paired flat_first_dim() + flat_last_dim() calls on the same tensor with a single flat_2d_dims() call. Saves one compute_shape() per tensor in CheckScaleTensorShape, the multi-tensor swizzle loop, and various cast/GEMM dispatch paths. Also adds reserve() to the local vectors in nvte_multi_tensor_swizzle_scaling_factors to avoid reallocation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Replace the inline swizzle implementation with a call to multi_tensor_swizzle_scales_for_gemm, which has identical logic (16B-aligned contiguous output buffer, TensorWrapper construction, nvte_multi_tensor_swizzle_scaling_factors kernel). Swizzled pointers are read back from the updated TensorWrappers after the call. Add reserve() to vectors in multi_tensor_swizzle_scales_for_gemm_impl now that this function is on the hot path for get_device_pointer_for_data_and_scales. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile SummaryThis PR replaces
Confidence Score: 4/5Safe to merge; the performance-critical paths are correct and the new API is a clean improvement over the old one. The refactoring is well-structured and the core kernel logic is sound. The one gap worth tracking is that transform_and_load_data_ptrs_on_device computes allocation stride from tensors[0] without asserting that all other tensors share the same shape — the uniform contract is documented in a comment but not enforced. All current callers pass homogeneous weight tensors so this does not affect present behavior, but the missing guard leaves a latent trap for future callers. transformer_engine/pytorch/csrc/extensions/utils.cpp — the uniform-shape assumption in transform_and_load_data_ptrs_on_device is undocumented at the API boundary and unvalidated at runtime. Important Files Changed
Reviews (3): Last reviewed commit: "Formatter and review suggestions from @g..." | Re-trigger Greptile |
| dtype(static_cast<DType>(tensor.dtype)) {} | ||
|
|
||
| SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {} | ||
| SimpleTensor &operator=(const NVTEBasicTensor &tensor) { |
There was a problem hiding this comment.
Without this assignment operator, assigning from a NVTEBasicTensor triggers a heap allocator in the NVTEBasicTensor constructor. We do this assignment frequently within nvte_set_tensor_param_v2.
| NVTE_CHECK(data_tensors[0].is_cuda(), "data_tensors must be on CUDA."); | ||
| const auto device = data_tensors[0].device(); | ||
| auto stream = at::cuda::getCurrentCUDAStream(); | ||
| std::tuple<at::Tensor, std::optional<at::Tensor>> transform_and_load_data_ptrs_on_device( |
There was a problem hiding this comment.
I'm not committed to this name. I based it on std::transform. I suppose "map" would be more Python-focused, but that sounds worse.
- Use size_t in kernel tail loop (was int64_t)
- Zero-initialize Payload before memcpy (Payload{})
- Rename Payload members to kMaxBytes/kVectorSize/kMaxVectors (linter)
- Consistent at::empty shape pattern: {static_cast<int64_t>(N)}
- Drop intermediate swizzled_scales_bytes variable
- Add comment explaining uniform-stride assumption in
transform_and_load_data_ptrs_on_device
- Rename sfb_buffer -> _sfb_buffer (keepalive, not directly used)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
7946e5d to
48cc585
Compare
|
/te-ci |
Description
tex.get_device_pointer_for_data_and_scaleshas two problems:ints. But actually it takes the buffers from multiple MXFP8/NVFP4 tensors (all assumed to have the same shape), swizzles the scaling factors, and transfers the pointers to a GPU array in a CUDA Graph-friendly way.This PR makes several optimizations to reduce CPU overhead, mostly to reduce unnecessary heap allocations. I've also attempted to make the functionality more general and logical:
nvte_load_value_on_device: A general function for copying a small amount of data to GPU in a CUDA Graph-friendly way. Unlikenvte_convert_pointers_to_tensor, it makes no assumptions that the data is a list of pointers.tex.load_data_ptrs_on_device: Takes a list of tensors and puts their data pointers into a GPU buffer.tex.transform_and_load_data_ptrs_on_device: Performs a user-provided transform on a list of tensors and puts the resulting data pointers into a GPU buffer. Currently it only supports scale swizzles on uniformly shaped tensors, but the transform names help make the contracts explicit.With these changes, per-call CPU runtime has dropped from 85 us to 44 us. (Tentative, need to rerun benchmark).
This is progress toward #2897.
Type of change
Changes
NVTEBasicTensortotransformer_engine::SimpleTensortransformer_engine::Tensorshape functionstransformer_engine::Tensor::flat_2d_dimsto compute first and last dims simultaneouslynvte_load_value_on_devicetex.load_data_ptrs_on_deviceandtex.transform_and_load_data_ptrs_on_deviceChecklist: