Skip to content

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974

Open
vthumbe1503 wants to merge 14 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix
Open

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
vthumbe1503 wants to merge 14 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP Sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the fsdp group as the amax reduction group in the quantizer

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • DCP Sync Checkpoint loading

    • untyped_storage is now defined for the base QuantizedTensor to return empty storage. Untyped_storage refers to the backing storage that we use to create all the internal tensors. Since we use make_wrapper_subclass to create TE QuantizedTensors, we use dont have any backing storage associated with the tensor. data_ptr on our Custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains sharded param tensor for checkpointing. It does so by calling view(-1) on our Quantized sharded model parameters. We return back a dequantized 1D tensor in TE. So, the sharded tensor that FSDP2 maintains for checkpointing is BF16 and Quantized sharded param is our custom FP8 tensor. It evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(Quantized sharded parameter) to see if the same_tensor. With us returning empty storage now, this would never be equal to sharded tensor's untyped storage.
  • DCP Async Checkpointing

    • to_new_empty function with device="cpu" is being used in Async Checkpointing. This function returned Quantizer.make_empty without setting the device. For device = "cpu" we now dequantize. So that the Async checkpointing directly saves the bf16 data on disk and reload works fine.
  • NVFP4 Allgather Correctness issues

    • Allgather with FSDP2 was very far away from fp32 allgather for the same values. This was due to us not setting the amax reduction group in the quantizer.
  • TE_DType Serialization issues with DCP Checkpointing

    • DCP uses torch.load(weights_only=True), whose Unpickler rejects every GLOBAL reference that isn't in add_safe_globals — and getattr is intentionally not allow-listed.
    • So we override the default enum reduction in pybind:
default:      (getattr, (tex.DType, "kFloat8E4M3"))   # needs getattr + tex.DType allow-listed
pybind override: (tex.DType, (int_value,))            # only needs tex.DType allow-listed

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes several issues with FSDP2 DCP checkpoint loading for MXFP8/NVFP4 quantized tensors: untyped_storage() now returns a zero-byte empty storage so FSDP2's same-tensor identity check never falsely matches a quantized parameter against its BF16 staging buffer; a new _to_copy dispatch handler moves all inner buffers while preserving the QuantizedTensor subclass; CPU dequantization for MXFP8/NVFP4 bounces through CUDA so tex.dequantize can operate; and pickle reconstruction functions are promoted to module-level so torch.load(weights_only=True) can reference them without needing getattr in its safe-globals list.

  • QuantizedTensor.untyped_storage() / _to_copy handler – returns a fresh zero-byte UntypedStorage on the tensor's device (fixes FSDP2 staging identity check), and a new aten._to_copy.default branch moves every inner tensor buffer to the target device while reconstructing the correct subclass.
  • CPU dequantization for MXFP8/NVFP4_FromMXFP8Func and _FromNVFP4Func detect CPU-resident tensors and temporarily move them to CUDA before calling tex.dequantize, returning the result back to the original device.
  • Pickle / DCP safe-globals__reduce_ex__ on every tensor type now points at a module-level function (not a classmethod), and tex.DType gets a custom __reduce_ex__ in the pybind11 binding that encodes as (tex.DType, (int,)), eliminating the getattr dependency.

Confidence Score: 5/5

The changes are targeted bug fixes with no functional regressions identified; the new _to_copy dispatch, empty-storage override, CPU dequant bounce, and module-level pickle functions all behave correctly.

All four fix areas (FSDP2 identity check, async staging, NVFP4 allgather amax group, DCP safe-globals) are addressed consistently across tensor types. The quantizer __getstate__ methods correctly exclude non-picklable process groups. The only findings are two documentation nits (a duplicated docstring phrase and a truncated comment), neither of which affects runtime behavior.

No files require special attention; the two documentation issues are in transformer_engine/pytorch/tensor/mxfp8_tensor.py and transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py.

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds untyped_storage() returning zero-byte storage, fixes cpu() to preserve QuantizedTensor type, adds _to_copy dispatch for device moves, and propagates device in _make_tensor.
transformer_engine/pytorch/init.py Registers module-level reconstruct functions and quantizer/storage types as safe globals for torch.load(weights_only=True); getattr is no longer in the list.
transformer_engine/common/util/pybind_helper.h Adds __reduce_ex__/__reduce__ overrides to the tex.DType pybind11 enum to serialize as (tex.DType, (int,)) instead of the unsafe (getattr, (tex.DType, name)) form.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py CPU dequantization now bounces through CUDA: detects non-CUDA device, moves tensor with .to(device="cuda"), dequantizes, then moves result back.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Same CPU-bounce pattern as MXFP8, but the inline comment at line 55 is truncated ("If the tensor has"). Functionally correct.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Promotes __reduce_ex__ to module-level function; propagates device across all tensor construction sites. Docstring in _make_mxfp8_tensor_in_reduce_ex contains a duplicated phrase (copy-paste artifact).
transformer_engine/pytorch/tensor/nvfp4_tensor.py Promotes __reduce_ex__ to module-level, propagates device, adds FSDP2 amax-reduction-group fix. NVFP4Quantizer.__getstate__ correctly excludes the non-picklable process group.
transformer_engine/pytorch/tensor/float8_tensor.py Promotes reconstruct function to module level, removes CPU-dequantize fallback in __reduce_ex__ (CPU Float8Tensor already has a manual CPU fallback in _FromFloat8Func), propagates device across construction sites.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Promotes reconstruct function to module level, removes subclass untyped_storage() override (now handled by base class), propagates device to all construction sites.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Removes device from get_metadata() to prevent it from overriding the target_device set by the new _to_copy handler in QuantizedTensor.__torch_dispatch__.
transformer_engine/pytorch/module/base.py Extends the amax_reduction_group assignment to NVFP4Quantizer in FSDP2/DTensor paths, fixing NVFP4 allgather numerical errors.
tests/pytorch/test_quantized_tensor.py Adds test_cpu_dequantize covering the new CPU-bounce dequantization path for all quantization formats with strict bitwise tolerance.

Reviews (8): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior is unchanged with the change. And I would argue the implementation now is more correct with the change. untyped_storage() default implementation from QuantizedTensor(torch.Tensor) before this change, gives a storage with two properties.

  1. storage.nbytes() returns bytes based on the fake_dtype that we use to register our QuantizedTensor as a torchTensor using make_wrapper_subclass method of torch.

  2. storage.data_ptr() gives an error saying it is an invalid storage and there is no data_ptr()

Both of them is not ideal.
The first one is grossly incrorrect due to two reasons. First we manage the backing storage for the inner tensors of QuantizedTensor and torch has no idea about it. Second nbytes based on fake_dtype is misleading since that might not actually be the number of bytes we actually allocate.
Second one is causing problems with FSDP2 now since it expects some storage for identity check.

For QuantizedTensor, noop_cat today always returns an actual torch.cat which goes through a dequantization luckily due to this condition being true. This condition is going to be true now with the change as well since nbytes() would return 0.

If we do QuantizedTensor.data_ptr() today it gives you 0. QuantizedTensor.untyped_storage().data_ptr() will give invalid storage error which is inconsistent. And giving empty storage as empty storage will fix this inconsitency.

As far as idenity checking goes, FSDP2 does all the comparisong logic only if data_ptr() is not 0. And it also doesnt really make sense to compare two empty storages.

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need dequant + quant here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are doing it anymore

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:

  • It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
  • Would this affect FSDP or CPU offloading?
  • Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it was ugly. Now we have a new solution where we have implemented the to_copy function in torch dispatch. This allows for staging the inner tensors of QuantizedTensor on CPU in a blocking/non-blocking way for sync/async DCP checkpointing.

We only do this in to_copy if dtype is unchanged. Otherwise we still go through the dequantize route.

# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
target_size = size

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the torch dispatch function now. So we dont have size here

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Outdated
Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp2_dcp_laod_fix branch from 3589ffa to 4197bee Compare May 13, 2026 04:00
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner May 13, 2026 04:00
pre-commit-ci Bot and others added 2 commits May 13, 2026 04:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/__init__.py Outdated
vthumbe1503 and others added 6 commits May 13, 2026 04:19
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 added the bug Something isn't working label May 18, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.16.0 bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants