[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974vthumbe1503 wants to merge 14 commits into
Conversation
Greptile SummaryThis PR fixes several issues with FSDP2 DCP checkpoint loading for MXFP8/NVFP4 quantized tensors:
Confidence Score: 5/5The changes are targeted bug fixes with no functional regressions identified; the new All four fix areas (FSDP2 identity check, async staging, NVFP4 allgather amax group, DCP safe-globals) are addressed consistently across tensor types. The quantizer No files require special attention; the two documentation issues are in Important Files Changed
Reviews (8): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile |
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
Empty storage breaks shared-storage detection in existing callers
QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
There was a problem hiding this comment.
Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.
There was a problem hiding this comment.
Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors
There was a problem hiding this comment.
The behavior is unchanged with the change. And I would argue the implementation now is more correct with the change. untyped_storage() default implementation from QuantizedTensor(torch.Tensor) before this change, gives a storage with two properties.
-
storage.nbytes() returns bytes based on the fake_dtype that we use to register our QuantizedTensor as a torchTensor using make_wrapper_subclass method of torch.
-
storage.data_ptr() gives an error saying it is an invalid storage and there is no data_ptr()
Both of them is not ideal.
The first one is grossly incrorrect due to two reasons. First we manage the backing storage for the inner tensors of QuantizedTensor and torch has no idea about it. Second nbytes based on fake_dtype is misleading since that might not actually be the number of bytes we actually allocate.
Second one is causing problems with FSDP2 now since it expects some storage for identity check.
For QuantizedTensor, noop_cat today always returns an actual torch.cat which goes through a dequantization luckily due to this condition being true. This condition is going to be true now with the change as well since nbytes() would return 0.
If we do QuantizedTensor.data_ptr() today it gives you 0. QuantizedTensor.untyped_storage().data_ptr() will give invalid storage error which is inconsistent. And giving empty storage as empty storage will fix this inconsitency.
As far as idenity checking goes, FSDP2 does all the comparisong logic only if data_ptr() is not 0. And it also doesnt really make sense to compare two empty storages.
|
/te-ci L1 pytorch |
| msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", | ||
| ) | ||
| elif recipe_name == "NVFP4BlockScaling": | ||
| # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances |
There was a problem hiding this comment.
Why do we need dequant + quant here?
There was a problem hiding this comment.
We are doing it anymore
| # When a CPU copy of a quantized tensor is requested (e.g. by | ||
| # torch DCP staging via ``x.new_empty(..., device="cpu")``), we | ||
| # save the high-precision values in a plain CPU dense tensor. | ||
| # For the DCP load path, we will re-quantize the high-precision values. |
There was a problem hiding this comment.
This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:
- It's uncomfortable how
new_emptyandempty_likewould have different behavior. I suppose we could interpretempty_likeas "make a tensor that matches the input" andnew_emptyas "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows. - Would this affect FSDP or CPU offloading?
- Given the weirdness, would it be worthwhile raising a warning if
new_emptyis called outside of DCP?
There was a problem hiding this comment.
I agree it was ugly. Now we have a new solution where we have implemented the to_copy function in torch dispatch. This allows for staging the inner tensors of QuantizedTensor on CPU in a blocking/non-blocking way for sync/async DCP checkpointing.
We only do this in to_copy if dtype is unchanged. Otherwise we still go through the dequantize route.
| # torch DCP staging via ``x.new_empty(..., device="cpu")``), we | ||
| # save the high-precision values in a plain CPU dense tensor. | ||
| # For the DCP load path, we will re-quantize the high-precision values. | ||
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() |
There was a problem hiding this comment.
An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).
>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() | |
| target_size = size |
There was a problem hiding this comment.
Changed the torch dispatch function now. So we dont have size here
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
3589ffa to
4197bee
Compare
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci L1 pytorch |
Description
Fixes DCP Sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the fsdp group as the amax reduction group in the quantizer
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
DCP Sync Checkpoint loading
DCP Async Checkpointing
NVFP4 Allgather Correctness issues
TE_DType Serialization issues with DCP Checkpointing
Checklist: