[PyTorch] Remove is_first_microbatch setting after cudagraph warmup #2715
yaox12 merged 2 commits into NVIDIA:main
Conversation
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed from 886f8cb to 25af05a
Greptile Summary: Removes the post-warmup reset of `is_first_microbatch`. Confidence Score: 3/5
Important Files Changed
Last reviewed commit: 7abebf0
/te-ci pytorch
```python
for module in func.modules():
    if hasattr(module, "is_first_microbatch"):
        module.is_first_microbatch = True
torch.cuda.synchronize()
```
Check that MCore modules no longer rely on the `is_first_microbatch` attribute being reset after warmup. The removed code handled module attributes, while the PR description discusses function parameters; these are different mechanisms. Verify that removing this reset doesn't break MCore integration where modules may have `is_first_microbatch` as an instance attribute that affects control flow.
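The review comment's distinction can be made concrete with a small sketch. This is a hypothetical `Layer` class, not TE or MCore code; it only illustrates the two mechanisms: a stored instance attribute (what the removed warmup loop reset) versus a per-call keyword argument (what the PR description discusses).

```python
# Hypothetical sketch: instance attribute vs. per-call parameter.
class Layer:
    def __init__(self):
        # Module state, the kind of attribute the removed loop reset to True.
        self.is_first_microbatch = True

    def forward(self, x, is_first_microbatch=None):
        # A per-call argument, when given, overrides the stored attribute.
        first = (self.is_first_microbatch
                 if is_first_microbatch is None else is_first_microbatch)
        mode = "overwrite" if first else "accumulate"
        return mode, x

layer = Layer()
assert layer.forward(1.0)[0] == "overwrite"   # falls back to the attribute
layer.is_first_microbatch = False             # what warmup leaves behind
assert layer.forward(1.0)[0] == "accumulate"
assert layer.forward(1.0, is_first_microbatch=True)[0] == "overwrite"
```

If MCore only ever passes the argument explicitly, the attribute's post-warmup value is irrelevant; if any module falls back to the attribute, removing the reset changes its behavior.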
…NVIDIA#2715)
Remove is_first_microbatch setting after warmup
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Description
Resetting `is_first_microbatch` to True is unnecessary for fp8, because later we'll set it back to False: https://github.com/NVIDIA/TransformerEngine/blob/release_v2.12/transformer_engine/pytorch/module/layernorm_mlp.py#L2037-L2042. It was used to control fp8 cast-transpose behavior, which is superseded by the `skip_fp8_weight_update` tensor in graph replay.

For bf16, the reset is not only unnecessary but also conflicts with an optimization in MCore: when PP is not enabled, we capture only one graph per layer and share it across microbatches. `is_first_microbatch` controls the behavior of the fused wgrad: `D = A * B` for the first microbatch, and `D = A * B + D` for the following microbatches. We should always capture the `D = A * B + D` version (i.e., capture with `is_first_microbatch=False`) so the graph works for all microbatches. Without the reset, `is_first_microbatch` stays False after warmup, as required.

Type of change
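A minimal NumPy sketch of the fused-wgrad semantics described above (the function name `fused_wgrad` is hypothetical, not a TE API). A CUDA graph bakes in exactly one code path, so the accumulate form `D = A @ B + D` is the one to capture: with the accumulator zeroed before the first replay, it reduces to `D = A @ B` for the first microbatch as well.

```python
import numpy as np

# Hypothetical sketch of the two wgrad behaviors:
#   is_first_microbatch=True  -> D = A @ B      (overwrite)
#   is_first_microbatch=False -> D = A @ B + D  (accumulate)
def fused_wgrad(A, B, D, is_first_microbatch):
    if is_first_microbatch:
        D[...] = A @ B
    else:
        D += A @ B
    return D

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(3, 4))

# Capture the accumulate path; zero D so the first "replay" is still correct.
D = np.zeros((2, 4))
fused_wgrad(A, B, D, is_first_microbatch=False)
assert np.allclose(D, A @ B)        # first microbatch: A@B + 0 == A@B

fused_wgrad(A, B, D, is_first_microbatch=False)
assert np.allclose(D, 2 * (A @ B))  # later microbatches accumulate
```

This is why the captured graph must use `is_first_microbatch=False`: the overwrite path would discard accumulated gradients on every replay after the first.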
Changes
Checklist: