Skip to content

[PyTorch Debug] Fix scale_inv_min returning 0 for MXFP8/NVFP4#3041

Open
pggPL wants to merge 3 commits into
NVIDIA:mainfrom
pggPL:fix_scale_inv_min_padding
Open

[PyTorch Debug] Fix scale_inv_min returning 0 for MXFP8/NVFP4#3041
pggPL wants to merge 3 commits into
NVIDIA:mainfrom
pggPL:fix_scale_inv_min_padding

Conversation

@pggPL
Copy link
Copy Markdown
Collaborator

@pggPL pggPL commented May 25, 2026

Description

The debug API stat scale_inv_min was always reporting 0 for MXFP8 and NVFP4 recipes.

Root cause: MXFP8Quantizer.get_scale_shape and NVFP4Quantizer.get_scale_shape round the scale_inv shape up to multiples of 128 along one axis and 4 along the other:

  • MXFP8 rowwise: [round_up(M, 128), round_up(N/32, 4)]
  • MXFP8 columnwise: [round_up(M/32, 4), round_up(N, 128)]
  • NVFP4 rowwise: [round_up(M, 128), round_up(⌈K/16⌉, 4)]
  • NVFP4 columnwise: [round_up(K, 128), round_up(⌈M/16⌉, 4)]

The padded buffer is allocated upfront at the padded shape (at::empty(padded_shape) in quantizer.cpp), and the cast kernel issues cudaMemsetAsync(scale_inv, 0, ...) over the whole buffer before filling the valid region (see quantize_mxfp8.cuh:780-794), so padded
slots end up as literal zeros. add_scale_inv_stats in stats_computation.py was calling .min() over the whole padded tensor, which therefore always returned 0 whenever the unpadded scale shape did not already meet the alignment.

A real scale_inv entry is never zero: compute_scale_from_amax (recipe_common.cuh) returns scale = 1.0 for all-zero data blocks and clamps inf to a finite fallback, so zero in scale_inv uniquely identifies padding. Masking zeros before .min() is therefore
exact, not heuristic. .max() is unaffected (zero cannot be the maximum of a non-negative tensor) and is left as-is.

Fixes #2628

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add nonzero_min helper in add_scale_inv_stats (transformer_engine/debug/features/utils/stats_computation.py) that masks out padding zeros before computing the minimum of scale_inv.
  • Route the scale_inv_min per-step lambda through nonzero_min; scale_inv_max is unchanged.
  • Degenerate all-zero case returns a 0 scalar of the right dtype/device so the buffer aggregator (min(_get(...))) stays well-defined.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

MXFP8/NVFP4 quantizers pad scale_inv to multiples of [128, 4] (or
[4, 128] columnwise) with zeros, so a plain .min() over the whole
tensor was always returning 0. Mask zeros out before computing the
minimum.

Fixes NVIDIA#2628

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 25, 2026

Greptile Summary

This PR fixes a debug-stat bug where scale_inv_min always reported 0 for MXFP8/NVFP4 recipes because those quantizers pad scale_inv to tile-aligned shapes with zeros, and .min() was called over the full padded tensor.

  • Introduces a nonzero_min helper that masks scale_inv == 0 entries before calling .min(), relying on the guarantee that real scale-inverse values are never zero.
  • The degenerate all-zeros fallback returns a zero scalar to keep the buffer aggregation lambda (min(_get(...))) well-typed; the PR notes this path is unreachable in practice.
  • scale_inv_max is left unchanged, which is correct since padding zeros cannot be the maximum of a non-negative tensor.

Confidence Score: 5/5

Safe to merge — the change is a single-function addition with a well-reasoned invariant, and the rest of the stat-registration logic is untouched.

The fix is minimal and correct: it filters exactly the values that padding introduces (literal zeros), the invariant that no real scale-inverse is zero is verified by the comment tracing it to compute_scale_from_amax, and the fallback for the all-zero case is a no-op in practice. No existing call sites or stat aggregation contracts are broken.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/debug/features/utils/stats_computation.py Adds nonzero_min helper inside add_scale_inv_stats to mask out padding zeros before computing scale_inv_min; fixes always-zero stat for MXFP8/NVFP4 recipes.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["add_scale_inv_stats() called"] --> B["get_scale_inv()"]
    B --> C{"Has _scale_inv attr?"}
    C -- Yes --> D["Return _scale_inv"]
    C -- No --> E{"columnwise?"}
    E -- Yes --> F["Return _columnwise_scale_inv"]
    E -- No --> G["Return _rowwise_scale_inv"]

    D & F & G --> H["nonzero_min(scale_inv)"]
    H --> I["nz = scale_inv[scale_inv != 0]"]
    I --> J{"nz.numel() == 0?"}
    J -- Yes (unreachable) --> K["return new_zeros(()) scalar"]
    J -- No --> L["return nz.min()"]

    L & K --> M["Per-step lambda result"]
    M --> N["Buffer aggregator: min(_get(buffers, stat_name_min))"]
    N --> O["scale_inv_min stat"]
Loading

Reviews (3): Last reviewed commit: "Clarify scale_inv padding comment" | Re-trigger Greptile

Comment thread transformer_engine/debug/features/utils/stats_computation.py
pre-commit-ci Bot and others added 2 commits May 25, 2026 09:18
The previous wording said the padding was always [128, 4] / [4, 128],
which is true for MXFP8 but inaccurate for NVFP4 columnwise (padded to
[128, 4], not [4, 128]). Also note that scale_inv is never naturally 0
(compute_scale_from_amax returns 1.0 for all-zero blocks), so masking
zeros is exact rather than heuristic.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented May 25, 2026

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

debug api scale_inv_min is always zero when using padding or divisibility requirements

1 participant