
pass params_dtype to qk_norm creation #2718

Open
pstjohn wants to merge 3 commits into NVIDIA:main from pstjohn:pstjohn/qk-norm-dtype

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Feb 28, 2026

Previously, layers would fail with:

            assert (
>               query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
            ), "Queries, keys and values must have the same data type!"
E           AssertionError: Queries, keys and values must have the same data type!

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py:1063: AssertionError

if you created a layer with a dtype other than float32. This change ensures that the dtype of the layernorm layers matches that of the base attention layer.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 28, 2026

Greptile Summary

This PR fixes a dtype mismatch bug in MultiheadAttention when using QK normalization with non-float32 dtypes. The fix ensures the RMSNorm and LayerNorm normalization modules receive the params_dtype parameter, preventing the assertion failure that occurred when queries, keys, and values ended up with different dtypes.

Key changes:

  • Modified _create_qk_norm_modules() to accept and pass params_dtype to RMSNorm/LayerNorm constructors
  • L2Normalization correctly excluded (parameter-free operation)
  • Added test parametrization for torch.float32 and torch.bfloat16 to verify the fix
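A dtype parametrization of this kind could look like the following hedged sketch; the test name and body are illustrative (the real test in tests/pytorch/test_qk_norm.py exercises Transformer Engine's attention layers, while this stand-in checks the same invariant on a plain PyTorch norm):

```python
import pytest
import torch

@pytest.mark.parametrize("params_dtype", [torch.float32, torch.bfloat16])
def test_qk_norm_dtype(params_dtype):
    # Stand-in norm created with an explicit dtype, mirroring the fix.
    norm = torch.nn.LayerNorm(16, dtype=params_dtype)
    q = torch.randn(2, 16, dtype=params_dtype)
    # The normalized queries must keep the layer's params_dtype, or the
    # q/k/v dtype assertion in DotProductAttention would fail.
    assert norm(q).dtype == params_dtype
```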

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it's a targeted bug fix with comprehensive test coverage
  • The fix is straightforward and correct: it passes the existing params_dtype parameter through to normalization layers to ensure dtype consistency. The change is well-tested with parametrized tests covering both float32 and bfloat16. L2Normalization is correctly excluded as it's parameter-free. No breaking changes or edge cases identified.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/attention/multi_head_attention.py: passes the params_dtype parameter to the QK normalization layers (RMSNorm/LayerNorm) to ensure dtype consistency with the main attention layer
  • tests/pytorch/test_qk_norm.py: adds test coverage for different params_dtype values (float32, bfloat16) to verify the dtype fix works correctly

Last reviewed commit: 4db2067

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


Member

@ksivaman ksivaman left a comment


LGTM

@yaox12
Member

yaox12 commented Mar 2, 2026

/te-ci pytorch

