Add fused_adam, quantized_model_init, and fsdp2 example #2698
pstjohn wants to merge 1 commit into NVIDIA:main
Conversation
Greptile Summary

This PR enables FusedAdam to work with PyTorch-native FSDP2 (`fully_shard`) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor objects. Key changes: a `fuse_wgrad_accumulation` guard for vanilla FSDP2 and new `quantized_model_init` examples.

The PR correctly documents that `fuse_wgrad_accumulation` remains incompatible with vanilla FSDP2. Confidence Score: 4/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant FSDP2 as FSDP2 fully_shard
    participant Param as DTensor[QuantizedTensor]
    participant FusedAdam
    participant State as Optimizer State
    User->>FSDP2: Apply sharding to model
    FSDP2->>Param: Wrap params as DTensor
    Note over Param: Parameters are DTensor-wrapped<br/>QuantizedTensor objects
    User->>FusedAdam: Create optimizer(params)
    FusedAdam->>FusedAdam: initialize_state(param)
    alt DTensor Check
        FusedAdam->>Param: Check isinstance(param, DTensor)
        FusedAdam->>Param: Extract param._local_tensor
        Note over FusedAdam: Avoid __torch_dispatch__<br/>dtype issues
    end
    alt QuantizedTensor Check
        FusedAdam->>Param: Check isinstance(local_tensor, QuantizedTensor)
        FusedAdam->>Param: local_tensor.dequantize(dtype=float32)
        Note over FusedAdam: Get plain float32 tensor<br/>for master weight
    end
    FusedAdam->>State: Create exp_avg (float32)
    FusedAdam->>State: Create exp_avg_sq (float32)
    FusedAdam->>State: Create master_param (float32)
    Note over State: All optimizer states are<br/>plain tensors (not DTensor/<br/>QuantizedTensor)
    User->>FusedAdam: optimizer.step()
    FusedAdam->>State: Update states with gradients
    FusedAdam->>Param: Update DTensor parameters
    Note over Param: DTensor wrapper and<br/>QuantizedTensor preserved
```
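The state-initialization order in the diagram (unwrap DTensor first, then dequantize, then build plain float32 states) can be sketched in pure Python. The stub classes and the `initialize_state` helper below are illustrative stand-ins, not the real `torch.distributed.tensor.DTensor`, Transformer Engine `QuantizedTensor`, or `FusedAdam` code:

```python
# Illustrative stand-ins for DTensor and QuantizedTensor -- NOT the real classes.
class QuantizedTensorStub:
    def __init__(self, quantized_values, scale):
        self.quantized_values = quantized_values
        self.scale = scale

    def dequantize(self):
        # Return plain float values suitable for a float32 master weight.
        return [v * self.scale for v in self.quantized_values]


class DTensorStub:
    def __init__(self, local_tensor):
        # Real DTensors hold the local shard in _local_tensor.
        self._local_tensor = local_tensor


def initialize_state(param):
    """Mirror the diagram: unwrap DTensor, dequantize, then build plain states."""
    # Step 1: unwrap the DTensor to avoid __torch_dispatch__ dtype issues.
    local = param._local_tensor if isinstance(param, DTensorStub) else param
    # Step 2: dequantize to get a plain float32-like master weight.
    master = local.dequantize() if isinstance(local, QuantizedTensorStub) else local
    # Step 3: all optimizer states are plain (unwrapped) tensors.
    return {
        "exp_avg": [0.0] * len(master),
        "exp_avg_sq": [0.0] * len(master),
        "master_param": master,
    }


param = DTensorStub(QuantizedTensorStub([1, 2, 4], scale=0.5))
state = initialize_state(param)
print(state["master_param"])  # [0.5, 1.0, 2.0]
```

The key design point the sketch captures: the DTensor must be unwrapped *before* the QuantizedTensor check, because `isinstance(param, QuantizedTensor)` would be false while the quantized tensor is still hidden inside the DTensor wrapper.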
Last reviewed commit: 5f0ebab
@pstjohn Hi, thanks for the great work! Does this PR plan to also handle the BF16 path? I noticed the BF16 branch still operates on the original p/p_grad without unwrapping when they're DTensors. In my experiments with FSDP2 + BF16, I'm seeing non-trivial overhead during the optimizer step from repeated DTensor dispatch. Curious if that's intentional or a planned follow-up.
```python
pytest.xfail(
    "MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
    "MXFP8 quantized tensors, causing illegal memory access"
)
```
claude's analysis:
Root cause: MXFP8Tensor inherits from QuantizedTensor but NOT from Float8Tensor. In fused_adam.py step():
- Line 617: isinstance(p, Float8Tensor) → False for MXFP8Tensor
- Line 629: p.dtype in [torch.float16, torch.bfloat16] → True (nominal dtype is bfloat16)
- So p.data (which is still an MXFP8Tensor wrapper with MXFP8-layout storage) gets added to p_f16_model
- The multi_tensor_adam CUDA kernel treats this as plain bf16 memory → illegal memory access
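The misrouting described above can be shown with a minimal pure-Python sketch. The stub classes and `classify` helper are illustrative (they only mimic the branch order described in the analysis, not the real `fused_adam.py` code):

```python
# Stub classes illustrating the dispatch bug described above; names are
# illustrative, not the real transformer_engine classes.
class QuantizedTensorStub:
    dtype = "bfloat16"  # nominal dtype as seen by the optimizer


class Float8TensorStub(QuantizedTensorStub):
    pass


class MXFP8TensorStub(QuantizedTensorStub):
    # Inherits from QuantizedTensor but NOT from Float8Tensor, and its
    # storage is MXFP8-laid-out rather than plain bf16.
    pass


def classify(p):
    """Mirror the branch order in fused_adam step() that misroutes MXFP8."""
    if isinstance(p, Float8TensorStub):
        return "fp8_bucket"  # correct path for Float8Tensor
    if p.dtype in ("float16", "bfloat16"):
        # MXFP8 lands here: the CUDA kernel then reads MXFP8-layout
        # storage as if it were plain bf16 memory -> illegal memory access.
        return "f16_bucket"
    return "fp32_bucket"


print(classify(Float8TensorStub()))  # fp8_bucket
print(classify(MXFP8TensorStub()))   # f16_bucket  (the bug)
```

A plausible fix under these assumptions would be to check `isinstance(p, QuantizedTensor)` before the nominal-dtype check, so all quantized wrappers are handled explicitly instead of falling through.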
```python
continue
# Extract local tensors from DTensors (e.g. from FSDP2)
# so that multi_tensor kernels receive plain CUDA tensors.
if isinstance(p_grad, DTensor):
```
Is this really needed? Since p_grad is high precision, all DTensor ops should get translated to local-tensor ops without needing to extract it.
hmm, yeah I guess not -- reverted!
```python
    column-wise FP8 data.
    """
    data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
```
Can you please provide a reason (maybe in the PR details) as to why this is needed? And where is this used? Is it needed for DCP checkpointing, and if so, why?
This would be in `tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_dcp_output_parity_async[MXFP8BlockScaling]`:

```
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 64, in fsdp_hook_wrapper
[rank1]:   return torch._dynamo.disable(
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1252, in _fn
[rank1]:   return fn(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 257, in _pre_forward
[rank1]:   args, kwargs = self._root_pre_forward(module, args, kwargs)
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 133, in _root_pre_forward
[rank1]:   self._lazy_init()
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 203, in _lazy_init
[rank1]:   state._fsdp_param_group.lazy_init()
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 261, in lazy_init
[rank1]:   fsdp_param.reset_sharded_param()
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 890, in reset_sharded_param
[rank1]:   == local_tensor.untyped_storage().data_ptr()
```
but we're currently xfailing that test anyway because of the error in multi_tensor_apply:

```
[rank0]: File "/workspace/transformer_engine/pytorch/optimizers/multi_tensor_apply.py", line 21, in __call__
[rank0]:   return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
[rank0]: RuntimeError: /workspace/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function multi_tensor_apply: CUDA Error: an illegal memory access was encountered
```
So I suppose we could revert this as well. I just added an additional description to the docstring as to why it's here.
/te-ci L1 pytorch
```python
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
    model_state = {
        k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
    }
```
Can you also please comment on why we should avoid saving _extra_state? As in, what error do we get with DCP if we don't?
This one I remember; the others I'll need to comment out and run the test suite again 😅
But this is a known hassle: torch DCP needs the sizes of these tensors to remain consistent between saving and loading, and since this is pickled data, its size changes whenever there's data in that field.
The alternative is a detailed load planner for DelayedScaling that reads and allocates the extra-state data tensor.
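The filtering step under discussion is just a key filter over the state dict before it is handed to DCP. A minimal sketch, with plain strings standing in for tensors (the helper name is illustrative, not TE API):

```python
# Minimal sketch: drop TE's pickled "_extra_state" entries before saving
# with torch DCP, since their serialized size can change between save and
# load. Plain strings stand in for tensors here.
def filter_extra_state(state_dict):
    return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


full = {
    "layer.weight": "tensor",
    "layer._extra_state": b"pickled-recipe-state",
    "layer.bias": "tensor",
}
print(sorted(filter_extra_state(full)))  # ['layer.bias', 'layer.weight']
```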
Summary
- Enables FusedAdam to work with PyTorch-native FSDP2 (fully_shard) when parameters are DTensor-wrapped Float8Tensor/QuantizedTensor
- Adds a fuse_wgrad_accumulation guard to avoid crashing with vanilla FSDP2 (previously assumed Megatron-Core FSDP exclusively)
- Adds quantized_model_init examples on single-GPU (main.py) and multi-GPU FSDP2 (fully_shard.py)

Note: fuse_wgrad_accumulation still cannot be used with vanilla FSDP2. The feature writes weight gradients directly into main_grad and returns None to autograd, bypassing FSDP2's reduce-scatter, so each rank ends up with an unreduced gradient. Megatron-Core FSDP solves this by wiring get_main_grad() into its own reduce-scatter infrastructure; vanilla FSDP2 does not yet expose an equivalent hook.

Fixes #2682
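Why bypassing reduce-scatter leaves ranks with wrong gradients can be shown with a toy two-rank simulation (plain floats, no torch; all names here are illustrative):

```python
# Toy two-rank simulation: with FSDP2's reduce-scatter, each rank's shard
# ends up holding the SUM of the per-rank local gradients. If
# fuse_wgrad_accumulation writes into main_grad and returns None to
# autograd, the reduction never runs and each rank keeps its own
# unreduced local gradient.
local_grads = {"rank0": 1.0, "rank1": 3.0}

# FSDP2 path: gradients are reduced across ranks before the optimizer step.
reduced = sum(local_grads.values())  # every rank's shard sees 4.0

# fuse_wgrad_accumulation path: autograd sees None, reduce-scatter is
# skipped, and main_grad holds whatever each rank computed locally.
main_grad = dict(local_grads)  # rank0 -> 1.0, rank1 -> 3.0

print(reduced)             # 4.0
print(main_grad["rank0"])  # 1.0  (unreduced -- the two ranks now disagree)
```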