
Add fused_adam, quantized_model_init, and fsdp2 example#2698

Open
pstjohn wants to merge 1 commit into NVIDIA:main from pstjohn:pstjohn/fsdp2-fused-adam

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 22, 2026

Summary

  • Fix FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
  • Fix fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
  • Add examples for quantized_model_init on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation remains incompatible with vanilla FSDP2

fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter. Each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure. Vanilla FSDP2 does not yet expose an equivalent hook.
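A minimal sketch of why this bypasses FSDP2 (hypothetical class and buffer names; the real logic lives in TE's linear autograd functions): when `backward` accumulates the weight gradient into a side buffer and returns `None`, `param.grad` is never populated, so FSDP2's post-backward reduce-scatter hook has nothing to reduce.

```python
import torch

class FusedWgradLinear(torch.autograd.Function):
    """Illustrative stand-in for the fuse_wgrad_accumulation pattern."""

    @staticmethod
    def forward(ctx, inp, weight, main_grad):
        ctx.save_for_backward(inp, weight)
        ctx.main_grad = main_grad
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # Weight gradient is accumulated in-place into main_grad...
        ctx.main_grad += grad_out.t() @ inp
        # ...and None is returned for the weight, so param.grad stays None
        # and FSDP2's reduce-scatter hook never sees a gradient.
        return grad_out @ weight, None, None

inp = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(16, 8, requires_grad=True)
main_grad = torch.zeros_like(weight)
out = FusedWgradLinear.apply(inp, weight, main_grad)
out.sum().backward()
assert weight.grad is None          # autograd never saw a weight gradient
assert main_grad.abs().sum() > 0    # it went into main_grad instead
```

Each rank's `main_grad` therefore holds only its local, unreduced contribution unless some other machinery (as in Megatron-Core FSDP) reduce-scatters it.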

Fixes #2682

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 2 times, most recently from 22604c4 to 4d89e04 on February 23, 2026 at 15:28
@pstjohn pstjohn marked this pull request as ready for review February 23, 2026 17:27
@greptile-apps
Contributor

greptile-apps bot commented Feb 23, 2026

Greptile Summary

This PR successfully enables FusedAdam to work with PyTorch-native FSDP2 when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor objects. The implementation extracts local tensors from DTensors before initializing optimizer states and creating master weights, avoiding issues with __torch_dispatch__ ignoring dtype kwargs.

Key changes:

  • FusedAdam optimizer: properly unwraps DTensors and dequantizes QuantizedTensors when creating optimizer states and master weights
  • Tensor classes: added untyped_storage() methods for FSDP2 compatibility and converted view/reshape RuntimeErrors to warnings with dequantization fallback
  • Quantizers: added __getstate__ methods to exclude unpicklable process groups from serialization
  • Test suite: comprehensive FSDP2+FusedAdam tests covering all FP8 recipe types with appropriate xfail markers for unsupported configurations
  • Examples: clear demonstrations of quantized_model_init for single-GPU and multi-GPU FSDP2 scenarios

The PR correctly documents that fuse_wgrad_accumulation remains incompatible with vanilla FSDP2 (test expects failure with strict xfail).

Confidence Score: 4/5

  • This PR is safe to merge with low risk - it fixes critical FSDP2 compatibility issues with well-tested changes
  • Score of 4 reflects solid implementation with comprehensive testing. The core optimizer and tensor changes properly handle DTensor wrapping and QuantizedTensor dequantization. Extensive test suite covers multiple FP8 recipe types with appropriate xfail markers for known limitations. Minor concern is the default parameter value inconsistency in test helper function, though it's unused in practice.
  • tests/pytorch/distributed/test_torch_fsdp2.py - minor default parameter inconsistency at line 84

Important Files Changed

Filename Overview
tests/pytorch/distributed/test_torch_fsdp2.py adds comprehensive test suite for FSDP2+FusedAdam with FP8 recipes, includes proper xfail markers for unsupported configurations
tests/pytorch/distributed/run_fsdp2_fused_adam.py new comprehensive test harness for FSDP2+FusedAdam combinations, validates DTensor wrapping and QuantizedTensor compatibility
transformer_engine/pytorch/optimizers/fused_adam.py fixed DTensor+QuantizedTensor handling by extracting local tensors before state initialization and master weight creation
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py converted RuntimeErrors to warnings with dequantization fallback for view/reshape operations incompatible with FSDP2
transformer_engine/pytorch/tensor/nvfp4_tensor.py added untyped_storage method for FSDP2 compatibility and converted reshape errors to warnings with fallback
transformer_engine/pytorch/tensor/mxfp8_tensor.py added untyped_storage method required by FSDP2's reset_sharded_param for data pointer access

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2 as FSDP2 fully_shard
    participant Param as DTensor[QuantizedTensor]
    participant FusedAdam
    participant State as Optimizer State

    User->>FSDP2: Apply sharding to model
    FSDP2->>Param: Wrap params as DTensor
    Note over Param: Parameters are DTensor-wrapped<br/>QuantizedTensor objects
    
    User->>FusedAdam: Create optimizer(params)
    FusedAdam->>FusedAdam: initialize_state(param)
    
    alt DTensor Check
        FusedAdam->>Param: Check isinstance(param, DTensor)
        FusedAdam->>Param: Extract param._local_tensor
        Note over FusedAdam: Avoid __torch_dispatch__<br/>dtype issues
    end
    
    alt QuantizedTensor Check
        FusedAdam->>Param: Check isinstance(local_tensor, QuantizedTensor)
        FusedAdam->>Param: local_tensor.dequantize(dtype=float32)
        Note over FusedAdam: Get plain float32 tensor<br/>for master weight
    end
    
    FusedAdam->>State: Create exp_avg (float32)
    FusedAdam->>State: Create exp_avg_sq (float32)
    FusedAdam->>State: Create master_param (float32)
    
    Note over State: All optimizer states are<br/>plain tensors (not DTensor/<br/>QuantizedTensor)
    
    User->>FusedAdam: optimizer.step()
    FusedAdam->>State: Update states with gradients
    FusedAdam->>Param: Update DTensor parameters
    Note over Param: DTensor wrapper and<br/>QuantizedTensor preserved

Last reviewed commit: 5f0ebab

Contributor

@greptile-apps greptile-apps bot left a comment


12 files reviewed, 13 comments


Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments


Member

@cspades cspades left a comment


LGTM, clean edits.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 0103b53 to 3c3dbd2 on February 24, 2026 at 20:06
Contributor

@greptile-apps greptile-apps bot left a comment


6 files reviewed, no comments



@XueSongTap

@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from a4d691f to 872caef on February 26, 2026 at 15:11
@pstjohn pstjohn marked this pull request as draft February 26, 2026 20:08
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch 4 times, most recently from 9ccc0c3 to eb8606a on February 26, 2026 at 21:55
@pstjohn pstjohn marked this pull request as ready for review February 26, 2026 21:56
@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from eb8606a to c2415e4 on February 26, 2026 at 22:50
Comment on lines +167 to +170
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access"
)
Contributor Author


claude's analysis:

Root cause: MXFP8Tensor inherits from QuantizedTensor but NOT from Float8Tensor. In fused_adam.py step():

  1. Line 617: isinstance(p, Float8Tensor) → False for MXFP8Tensor
  2. Line 629: p.dtype in [torch.float16, torch.bfloat16] → True (nominal dtype is bfloat16)
  3. So p.data (which is still an MXFP8Tensor wrapper with MXFP8-layout storage) gets added to p_f16_model
  4. The multi_tensor_adam CUDA kernel treats this as plain bf16 memory → illegal memory access
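The class-hierarchy mismatch behind this analysis can be reproduced in isolation. These are illustrative stand-in classes only, mirroring the inheritance described above, not the TE classes themselves:

```python
# Stand-in classes mirroring the hierarchy in the analysis (illustrative only).
class QuantizedTensor: ...
class Float8Tensor(QuantizedTensor): ...
class MXFP8Tensor(QuantizedTensor): ...  # inherits QuantizedTensor, NOT Float8Tensor

p = MXFP8Tensor()
# Step 1: the Float8 branch is skipped for MXFP8...
assert not isinstance(p, Float8Tensor)
# ...so a nominal-dtype check (step 2) routes the still-quantized wrapper
# into the bf16 list, and the CUDA kernel reads MXFP8-layout storage as
# plain bf16 memory.
assert isinstance(p, QuantizedTensor)
```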

continue
# Extract local tensors from DTensors (e.g. from FSDP2)
# so that multi_tensor kernels receive plain CUDA tensors.
if isinstance(p_grad, DTensor):
Collaborator

@vthumbe1503 vthumbe1503 Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed? Since p_grad is high precision, all of the DTensor's ops should get translated to the local tensor's ops without needing to extract it.

Contributor Author


hmm, yeah I guess not -- reverted!

column-wise FP8 data.

"""
data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
Collaborator


Can you please provide a reason (maybe in the PR details) as to why this is needed? And where is this used? Is this needed for DCP checkpointing, and if so, why?

Contributor Author


This would be in tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_dcp_output_parity_async[MXFP8BlockScaling],

[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 64, in fsdp_hook_wrapper
[rank1]:     return torch._dynamo.disable(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1252, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 257, in _pre_forward
[rank1]:     args, kwargs = self._root_pre_forward(module, args, kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 133, in _root_pre_forward
[rank1]:     self._lazy_init()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 203, in _lazy_init
[rank1]:     state._fsdp_param_group.lazy_init()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 261, in lazy_init
[rank1]:     fsdp_param.reset_sharded_param()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 890, in reset_sharded_param
[rank1]:     == local_tensor.untyped_storage().data_ptr()

but we're currently xfailing that test anyways because of the error in multi_tensor_apply,

[rank0]:   File "/workspace/transformer_engine/pytorch/optimizers/multi_tensor_apply.py", line 21, in __call__
[rank0]:     return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /workspace/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function multi_tensor_apply: CUDA Error: an illegal memory access was encountered

So I suppose we could revert this as well. I just added an additional description to the docstring as to why it's here.

@vthumbe1503
Collaborator

/te-ci L1 pytorch

vthumbe1503 previously approved these changes Feb 28, 2026
Collaborator

@vthumbe1503 vthumbe1503 left a comment


Thanks for the clean PR and great work. Left a few minor comments. LGTM post CI success.

if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
model_state = {
k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
}
Collaborator

@vthumbe1503 vthumbe1503 Feb 28, 2026


Can you also please comment why we should avoid saving _extra_state. As in what error we get with dcp if we dont do so?

Contributor Author


This one I remember, the others I'll need to comment out and run the test suite again 😅

But this is a known hassle where torch DCP needs the sizes of these tensors to remain consistent during saving & loading, and since this is pickled data, it changes when there's data in that field.

The alternative is a detailed load_planner for DelayedScaling that reads and allocates the extra state data tensor

@pstjohn pstjohn force-pushed the pstjohn/fsdp2-fused-adam branch from 099a3ab to 5f0ebab on March 2, 2026 at 15:23
@greptile-apps
Contributor

greptile-apps bot commented Mar 2, 2026

Additional Comments (1)

tests/pytorch/distributed/test_torch_fsdp2.py, line 84
The default value uses snake_case, but run_fsdp2_fused_adam.py's argparse expects PascalCase; change it to "DelayedScaling" for consistency:

def _run_fused_adam_test(test_name, recipe="DelayedScaling"):



Development

Successfully merging this pull request may close these issues.

Example of quantized_model_init for low-precision compute weights, fp32 main weights using fused_adam with fsdp2

4 participants