diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index b0002ffee6..6668400017 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -348,6 +348,21 @@ def get_scale_inv(quantized_tensor, columnwise): return getattr(quantized_tensor, "_columnwise_scale_inv") return getattr(quantized_tensor, "_rowwise_scale_inv") + def nonzero_min(scale_inv): + # MXFP8/NVFP4 quantizers round the scale_inv shape up to multiples of + # 128 along one axis and 4 along the other and fill the extra slots + # with zeros (via torch.nn.functional.pad with the default value=0), + # so a plain .min() always returns 0 for shapes that needed padding. + # A real scale_inv entry is never 0: compute_scale_from_amax returns + # scale=1.0 for all-zero blocks and clamps the inf case to a finite + # fallback, so zeros uniquely identify padding and masking them out + # gives the true minimum. The empty-after-mask branch is a safety + # net for the (in practice unreachable) all-zero tensor. + nz = scale_inv[scale_inv != 0] + if nz.numel() == 0: + return scale_inv.new_zeros(()) + return nz.min() + columnwise_suffix = "_columnwise" if columnwise else "" # Prepare stat names. stat_name_min = ( @@ -363,7 +378,9 @@ def get_scale_inv(quantized_tensor, columnwise): # Capture the attribute name inside lambdas via default args to avoid late binding. STATS[stat_name_min] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(), + lambda x, aux_dict, _col=columnwise: nonzero_min( + get_scale_inv(aux_dict[recipe_name], _col) + ), lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)), ) STATS[stat_name_max] = (