[PyTorch debug] FakeQuant: support Float8BlockScaling and fix MoE / w… #3040
shangxiaokang wants to merge 2 commits into
…eight-cache paths
for more information, see https://pre-commit.ci
Greptile Summary
This PR extends
Confidence Score: 3/5
The weight-cache write-back path works correctly after the api.py fix, but modify_tensor will crash with an AttributeError whenever out and dtype are both provided. When out is non-None, fake_quantize returns None, and the modify_tensor body then calls .to(dtype) on that None, so any caller that passes both arguments hits the crash. MXFP8 formats now also route through zero-padding, but that path's correctness has only been argued for Float8BlockQuantizer.
transformer_engine/debug/features/fake_quant.py: specifically the modify_tensor dtype handling and the MXFP8 entries in _FORMAT_DISPATCH.
Important Files Changed
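The AttributeError described in the confidence note can be reproduced without Transformer Engine. The following is a minimal sketch, not the PR's code: `FakeTensor`, `fake_quantize_stub`, `modify_tensor_buggy`, and `modify_tensor_guarded` are hypothetical stand-ins, and the only behaviour taken from the review is that the quantize helper returns None when out is provided. One possible guard, assumed here rather than confirmed by the PR, is to skip the dtype cast when there is no returned tensor.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; models only values and a dtype tag."""

    def __init__(self, values, dtype="float32"):
        self.values = list(values)
        self.dtype = dtype

    def to(self, dtype):
        # Return a copy tagged with the new dtype, loosely like Tensor.to.
        return FakeTensor(self.values, dtype)

    def copy_(self, other):
        # In-place copy of the values, loosely like Tensor.copy_.
        self.values = list(other.values)
        return self


def fake_quantize_stub(tensor, out=None):
    # Stand-in for fake_quantize: an identity "quantize then dequantize".
    dequantized = FakeTensor(tensor.values, tensor.dtype)
    if out is not None:
        out.copy_(dequantized)
        return None  # the in-place path returns nothing
    return dequantized


def modify_tensor_buggy(tensor, dtype=None, out=None):
    result = fake_quantize_stub(tensor, out=out)
    if dtype is not None:
        result = result.to(dtype)  # AttributeError when result is None
    return result


def modify_tensor_guarded(tensor, dtype=None, out=None):
    result = fake_quantize_stub(tensor, out=out)
    # Assumed fix: only cast when a tensor was actually returned.
    if result is not None and dtype is not None:
        result = result.to(dtype)
    return result
```

Whether the real fix should skip the cast, cast out in place, or reject the argument combination is a design decision for the PR author; the sketch only shows where the None originates.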
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fake_quantize called] --> B{format in dispatch?}
B -- No --> C[raise ValueError]
B -- Yes --> D[unpack factory and block_size]
D --> E{block_size not None?}
E -- Yes --> F[_check_blockwise_shape]
F --> G[_pad_for_blockwise]
G --> H[build quantizer]
E -- No --> H
H --> I[dequantized = quantize then dequantize]
I --> J{padding applied?}
J -- Yes --> K[slice and reshape]
K --> L{out not None?}
J -- No --> L
L -- Yes --> M{out has quantize_?}
M -- Yes --> N[out.quantize_ dequantized]
M -- No --> O[out.copy_ dequantized]
N --> P[return None]
O --> P
L -- No --> Q[return dequantized]
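The control flow in the flowchart can be sketched in plain Python. This is an illustrative model under stated assumptions, not the implementation in fake_quant.py: tensors are lists of rows, `_round_trip_factory` replaces a real quantizer with rounding to a grid, the format names `TOY_PER_TENSOR` and `TOY_BLOCKWISE` and the block size of 4 are invented, and the shape check shown is only a guess at what `_check_blockwise_shape` enforces.

```python
def _round_trip_factory(step):
    """Hypothetical factory: rounding to a grid stands in for quantize + dequantize."""
    def round_trip(rows):
        return [[round(v / step) * step for v in row] for row in rows]
    return round_trip


# Toy dispatch table: format name -> (quantizer factory, block_size or None).
_DISPATCH = {
    "TOY_PER_TENSOR": (_round_trip_factory, None),
    "TOY_BLOCKWISE": (_round_trip_factory, 4),
}


def _pad_rows(rows, block_size):
    # Zero-pad the leading dimension up to a multiple of block_size.
    n_pad = (-len(rows)) % block_size
    width = len(rows[0])
    return rows + [[0.0] * width for _ in range(n_pad)], n_pad


def fake_quantize_sketch(rows, fmt, out=None, step=0.5):
    if fmt not in _DISPATCH:
        raise ValueError(f"Unsupported format: {fmt}")
    factory, block_size = _DISPATCH[fmt]

    n_pad = 0
    if block_size is not None:
        # Assumed shape check: the last dimension must already be block-aligned.
        if len(rows[0]) % block_size != 0:
            raise ValueError("last dim must be divisible by block_size")
        rows, n_pad = _pad_rows(rows, block_size)

    quantizer = factory(step)
    dequantized = quantizer(rows)

    if n_pad:
        # Drop the padded rows so the result has the caller's original shape.
        dequantized = dequantized[: len(dequantized) - n_pad]

    if out is not None:
        if hasattr(out, "quantize_"):
            out.quantize_(dequantized)
        else:
            out[:] = dequantized  # list stand-in for out.copy_
        return None
    return dequantized
```

As in the flowchart, the function returns None whenever out is given, so callers that need the result must read it from out rather than from the return value.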
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
# Float8 blockwise: 2D 128x128 tiles
"FP8_BLOCKWISE_E4M3": (
The docstring states "MXFP8*: shape[-1] and prod(shape[:-1]) must both be divisible by 32", but because block_size is non-None for the MXFP8 entries, _pad_for_blockwise silently zero-pads the leading dim when it is not 32-aligned. The _pad_for_blockwise docstring explicitly describes Float8BlockQuantizer's clean zero-block behaviour, but does not guarantee the same for MXFP8Quantizer. If MXFP8Quantizer does not handle padded rows the same way, the output for the non-padded slice can be subtly wrong without raising any error.
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, MXFP8_BLOCK_SCALING_SIZE),
# Float8 blockwise: 2D 128x128 tiles
"FP8_BLOCKWISE_E4M3": (
# MXFP8 (1x32 block scaling) - block_size=None: shape check and padding are
# NOT applied because MXFP8Quantizer does not guarantee clean zero-block
# behaviour for padded rows. Both dims must already be 32-aligned (caller
# responsibility, same as the previous implementation).
"MXFP8E4M3": (_build_mxfp8_quantizer, tex.DType.kFloat8E4M3, {}, None),
"MXFP8E5M2": (_build_mxfp8_quantizer, tex.DType.kFloat8E5M2, {}, None),
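If the MXFP8 entries skip padding as suggested, alignment becomes the caller's responsibility. A caller-side check might look like the sketch below. It encodes only the rule quoted from the docstring (shape[-1] and prod(shape[:-1]) both divisible by 32); the helper name `check_mxfp8_alignment` and the constant `MXFP8_BLOCK` are hypothetical and are not part of Transformer Engine.

```python
from math import prod

MXFP8_BLOCK = 32  # hypothetical constant mirroring the "divisible by 32" rule


def check_mxfp8_alignment(shape):
    """Raise ValueError unless shape[-1] and prod(shape[:-1]) are multiples of 32."""
    last = shape[-1]
    leading = prod(shape[:-1])  # product of all leading dims (1 for a 1-D shape)
    if last % MXFP8_BLOCK != 0 or leading % MXFP8_BLOCK != 0:
        raise ValueError(
            f"MXFP8 needs shape[-1]={last} and prod(shape[:-1])={leading} "
            f"to both be divisible by {MXFP8_BLOCK}"
        )
```

Failing loudly on a misaligned shape trades convenience for safety: the caller sees an error instead of output that may have been influenced by padded rows.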
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: