
[PyTorch] Remove is_first_microbatch setting after cudagraph warmup#2715

Merged
yaox12 merged 2 commits into NVIDIA:main from buptzyb:robinz/graph_isfirstmicrobatch
Mar 2, 2026

Conversation


@buptzyb buptzyb commented Feb 27, 2026

Description

Resetting is_first_microbatch to True is unnecessary for FP8, because we later set it back to False anyway: https://github.com/NVIDIA/TransformerEngine/blob/release_v2.12/transformer_engine/pytorch/module/layernorm_mlp.py#L2037-L2042. The flag was used to control the FP8 cast-transpose behavior, which is now handled by the skip_fp8_weight_update tensor during graph replay.
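
The idea behind the skip_fp8_weight_update tensor can be illustrated with a minimal sketch. This is not TE's implementation; the variable and function names are illustrative. The point is that a Python bool is frozen into a CUDA graph at capture time, while a tensor's contents can be refreshed before every replay, so the flag must flow through tensor data:

```python
import torch

# Illustrative sketch (not TE's actual code): a flag tensor captured inside a
# graph can be refilled before each replay, whereas a Python bool cannot.
skip_weight_update = torch.zeros(1)  # 0.0 = run the weight update, 1.0 = skip it

def step(flag):
    # Read the flag as data so one captured graph serves both cases.
    do_update = 1.0 - flag  # 1.0 when the update should run on this replay
    return do_update

skip_weight_update.fill_(0.0)
assert step(skip_weight_update).item() == 1.0  # update runs
skip_weight_update.fill_(1.0)
assert step(skip_weight_update).item() == 0.0  # update skipped
```

Because the decision is carried by tensor contents rather than Python control flow, the captured graph no longer depends on the value of is_first_microbatch at capture time.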

For bf16, the reset is not only unnecessary but also conflicts with an optimization in MCore: when pipeline parallelism (PP) is not enabled, we capture only one graph per layer and share it across microbatches. is_first_microbatch controls the behavior of the fused wgrad: D = A*B for the first microbatch, and D = A*B + D for the following microbatches. We should always capture the D = A*B + D version (i.e., capture with is_first_microbatch=False) so that the graph works for all microbatches. Removing the reset guarantees is_first_microbatch stays False after warmup.
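
A conceptual sketch of why capturing the accumulating branch is the right choice (this models the described behavior, not the actual TE fused kernel; fused_wgrad is a hypothetical stand-in):

```python
import torch

# Conceptual model of the fused wgrad behavior described above (not the actual
# TE kernel): the first microbatch overwrites D, later microbatches accumulate.
def fused_wgrad(A, B, D, is_first_microbatch):
    if is_first_microbatch:
        D.copy_(A @ B)   # D = A * B
    else:
        D.add_(A @ B)    # D = A * B + D
    return D

# A captured CUDA graph freezes one branch. Capturing the accumulating branch
# (is_first_microbatch=False) is valid for every microbatch, provided D is
# zeroed before the first replay.
A, B = torch.ones(2, 3), torch.ones(3, 4)
D = torch.zeros(2, 4)
for _ in range(3):  # three microbatches, always replaying the "+D" path
    fused_wgrad(A, B, D, is_first_microbatch=False)
print(D[0, 0].item())  # 9.0: three accumulations of A@B, each entry is 3.0
```

Zeroing D between optimizer steps restores the effect of the D = A*B first-microbatch path without needing a second captured graph.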

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the robinz/graph_isfirstmicrobatch branch from 886f8cb to 25af05a on February 27, 2026 at 08:27

greptile-apps bot commented Feb 27, 2026

Greptile Summary

Removes the post-warmup reset of the is_first_microbatch module attribute to True. This code was originally added for Megatron-Core (MCore) compatibility, to prevent warmup from altering control flow. The author argues the reset is unnecessary because, during CUDA graph capture, the skip_fp8_weight_update tensor now controls FP8 weight-update behavior, and is_first_microbatch is automatically set to False when that tensor is present. For wgrad accumulation, capturing the accumulating version (is_first_microbatch=False) allows the graph to work correctly across all microbatches rather than just the first one.

Confidence Score: 3/5

  • Safe to merge with thorough testing, especially with MCore integration
  • The change is small and well-explained, but it removes code specifically added for Megatron-Core compatibility without including tests to validate the new behavior. The reasoning about skip_fp8_weight_update controlling the behavior is sound, but the removed code operated on module attributes while the explanation focuses on function parameters. This needs validation with MCore workloads to ensure no regression.
  • No files require special attention beyond standard MCore integration testing

Important Files Changed

Filename: transformer_engine/pytorch/graph.py
Overview: Removes the post-warmup is_first_microbatch attribute reset that was added for MCore compatibility

Last reviewed commit: 7abebf0

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment



yaox12 commented Mar 2, 2026

/te-ci pytorch

for module in func.modules():
    if hasattr(module, "is_first_microbatch"):
        module.is_first_microbatch = True
torch.cuda.synchronize()
Check that MCore modules no longer rely on the is_first_microbatch attribute being reset after warmup. The removed code handled module attributes while the PR description discusses function parameters - these are different mechanisms. Verify that removing this reset doesn't break MCore integration where modules may have is_first_microbatch as an instance attribute that affects control flow.

@yaox12 yaox12 left a comment


LGTM.

@yaox12 yaox12 merged commit f508e66 into NVIDIA:main Mar 2, 2026
21 of 24 checks passed
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Mar 2, 2026

[PyTorch] Remove is_first_microbatch setting after cudagraph warmup (NVIDIA#2715)

Remove is_first_microbatch setting after warmup

Signed-off-by: Robin Zhang <robinz@nvidia.com>
3 participants