Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion transformer_engine/debug/features/utils/stats_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,21 @@ def get_scale_inv(quantized_tensor, columnwise):
return getattr(quantized_tensor, "_columnwise_scale_inv")
return getattr(quantized_tensor, "_rowwise_scale_inv")

def nonzero_min(scale_inv):
    """Return the smallest non-zero entry of ``scale_inv`` as a 0-dim tensor.

    Why zeros are skipped: MXFP8/NVFP4 quantizers round the scale_inv shape
    up to multiples of 128 along one axis and 4 along the other, and fill
    the extra slots with zeros (torch.nn.functional.pad with its default
    value=0). A plain ``.min()`` therefore always reports 0 whenever the
    shape needed padding.

    Why that is safe: a genuine scale_inv entry is never 0, because
    compute_scale_from_amax yields scale=1.0 for all-zero blocks and clamps
    the inf case to a finite fallback. Zeros thus uniquely mark padding,
    and dropping them exposes the true minimum.

    If nothing survives the mask (an all-zero tensor, unreachable in
    practice), a zero scalar with the dtype/device of ``scale_inv`` is
    returned as a safety net.
    """
    # Boolean mask selecting the real (non-padding) entries.
    is_real_entry = scale_inv.ne(0)
    real_entries = scale_inv[is_real_entry]

    if real_entries.numel() > 0:
        return real_entries.min()

    # Safety net: every slot was zero, so there is no real minimum to report.
    return scale_inv.new_zeros(())
Comment thread
pggPL marked this conversation as resolved.

columnwise_suffix = "_columnwise" if columnwise else ""
# Prepare stat names.
stat_name_min = (
Expand All @@ -363,7 +378,9 @@ def get_scale_inv(quantized_tensor, columnwise):

# Capture the attribute name inside lambdas via default args to avoid late binding.
STATS[stat_name_min] = (
lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(),
lambda x, aux_dict, _col=columnwise: nonzero_min(
get_scale_inv(aux_dict[recipe_name], _col)
),
lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)),
)
STATS[stat_name_max] = (
Expand Down
Loading