[PyTorch Debug] Fix scale_inv_min returning 0 for MXFP8/NVFP4#3041
[PyTorch Debug] Fix scale_inv_min returning 0 for MXFP8/NVFP4#3041pggPL wants to merge 3 commits into
Conversation
MXFP8/NVFP4 quantizers pad scale_inv to multiples of [128, 4] (or [4, 128] columnwise) with zeros, so a plain .min() over the whole tensor was always returning 0. Mask zeros out before computing the minimum. Fixes NVIDIA#2628 Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Greptile SummaryThis PR fixes a debug-stat bug where
Confidence Score: 5/5Safe to merge — the change is a single-function addition with a well-reasoned invariant, and the rest of the stat-registration logic is untouched. The fix is minimal and correct: it filters exactly the values that padding introduces (literal zeros), the invariant that no real scale-inverse is zero is verified by the comment tracing it to compute_scale_from_amax, and the fallback for the all-zero case is a no-op in practice. No existing call sites or stat aggregation contracts are broken. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["add_scale_inv_stats() called"] --> B["get_scale_inv()"]
B --> C{"Has _scale_inv attr?"}
C -- Yes --> D["Return _scale_inv"]
C -- No --> E{"columnwise?"}
E -- Yes --> F["Return _columnwise_scale_inv"]
E -- No --> G["Return _rowwise_scale_inv"]
D & F & G --> H["nonzero_min(scale_inv)"]
H --> I["nz = scale_inv[scale_inv != 0]"]
I --> J{"nz.numel() == 0?"}
J -- Yes (unreachable) --> K["return new_zeros(()) scalar"]
J -- No --> L["return nz.min()"]
L & K --> M["Per-step lambda result"]
M --> N["Buffer aggregator: min(_get(buffers, stat_name_min))"]
N --> O["scale_inv_min stat"]
Reviews (3): Last reviewed commit: "Clarify scale_inv padding comment" | Re-trigger Greptile |
for more information, see https://pre-commit.ci
The previous wording said the padding was always [128, 4] / [4, 128], which is true for MXFP8 but inaccurate for NVFP4 columnwise (padded to [128, 4], not [4, 128]). Also note that scale_inv is never naturally 0 (compute_scale_from_amax returns 1.0 for all-zero blocks), so masking zeros is exact rather than heuristic. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
|
/te-ci pytorch |
Description
The debug API stat
scale_inv_minwas always reporting0for MXFP8 and NVFP4 recipes.Root cause:
MXFP8Quantizer.get_scale_shapeandNVFP4Quantizer.get_scale_shaperound thescale_invshape up to multiples of 128 along one axis and 4 along the other:[round_up(M, 128), round_up(N/32, 4)][round_up(M/32, 4), round_up(N, 128)][round_up(M, 128), round_up(⌈K/16⌉, 4)][round_up(K, 128), round_up(⌈M/16⌉, 4)]The padded buffer is allocated upfront at the padded shape (
at::empty(padded_shape)inquantizer.cpp), and the cast kernel issuescudaMemsetAsync(scale_inv, 0, ...)over the whole buffer before filling the valid region (seequantize_mxfp8.cuh:780-794), so paddedslots end up as literal zeros.
add_scale_inv_statsinstats_computation.pywas calling.min()over the whole padded tensor, which therefore always returned0whenever the unpadded scale shape did not already meet the alignment.A real
scale_inventry is never zero:compute_scale_from_amax(recipe_common.cuh) returnsscale = 1.0for all-zero data blocks and clamps inf to a finite fallback, so zero inscale_invuniquely identifies padding. Masking zeros before.min()is thereforeexact, not heuristic.
.max()is unaffected (zero cannot be the maximum of a non-negative tensor) and is left as-is.Fixes #2628
Type of change
Changes
Please list the changes introduced in this PR:
nonzero_minhelper inadd_scale_inv_stats(transformer_engine/debug/features/utils/stats_computation.py) that masks out padding zeros before computing the minimum ofscale_inv.scale_inv_minper-step lambda throughnonzero_min;scale_inv_maxis unchanged.0scalar of the right dtype/device so the buffer aggregator (min(_get(...))) stays well-defined.Checklist: