[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests #2702

Open
phu0ngng wants to merge 14 commits into NVIDIA:main from phu0ngng:rm_gspmd

@phu0ngng
Collaborator

Description

GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner. This commit removes all GSPMD-related code paths and tests:

  • Drop the infer_sharding_from_operands abstract method from BasePrimitive and remove it from def_partition() registration
  • Remove all infer_sharding_from_operands implementations across cpp_extensions: activation, amax, attention, gemm, normalization, quantization, and softmax primitives
  • Remove stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
  • Drop all use_shardy=False (GSPMD) distributed test paths and the jax.config.update("jax_use_shardy_partitioner", ...) config calls
  • Consolidate paired GSPMD/Shardy test functions into single tests and strip _shardy suffixes from test names
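The registration change in the first bullet can be sketched without a JAX install. This is a minimal illustration, not the Transformer Engine code: `RecordingPartitioner`, `ExamplePrimitive`, and `register` are hypothetical stand-ins, and only the keyword names `partition` and `sharding_rule` are taken from the `def_partition()` call this PR edits.

```python
class RecordingPartitioner:
    """Hypothetical stand-in for the object returned by custom_partitioning(...).

    It only records which callbacks were registered, so the sketch runs
    without JAX.
    """

    def __init__(self, impl, static_argnums=()):
        self.impl = impl
        self.static_argnums = static_argnums
        self.registered = {}

    def def_partition(self, **callbacks):
        # Store the keyword arguments exactly as they were passed in.
        self.registered = dict(callbacks)


class ExamplePrimitive:
    """Hypothetical primitive exposing only the Shardy-era callbacks."""

    impl_static_args = ()

    @staticmethod
    def impl(x):
        return x

    @staticmethod
    def partition(*args):
        raise NotImplementedError("illustrative only")

    @staticmethod
    def shardy_sharding_rule(*args):
        raise NotImplementedError("illustrative only")


def register(cls, partitioner_factory=RecordingPartitioner):
    """Register a primitive with no infer_sharding_from_operands callback."""
    outer_p_lower = partitioner_factory(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(
        partition=cls.partition,
        sharding_rule=cls.shardy_sharding_rule,
    )
    return outer_p_lower


registered = register(ExamplePrimitive).registered
print(sorted(registered))
```

Under these assumptions, only two callbacks are registered, so a primitive no longer needs to implement the GSPMD inference hook at all.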

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Feb 24, 2026

Greptile Summary

This PR successfully deprecates GSPMD sharding propagation in favor of Shardy, which is now the default JAX partitioner. The changes are comprehensive and systematic:

Core Changes:

  • Removed the abstract method infer_sharding_from_operands from BasePrimitive in base.py
  • Removed infer_sharding_from_operands from def_partition() registration
  • Deleted all infer_sharding_from_operands implementations across 9 primitive files (activation, amax, attention, gemm, normalization, quantization, softmax in cpp_extensions, plus permutation in triton_extensions)
  • Cleaned up stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods

Test Consolidation:

  • Removed use_shardy parameter from all distributed tests
  • Removed jax.config.update("jax_use_shardy_partitioner", ...) config calls
  • Consolidated duplicate GSPMD/Shardy test methods into single tests
  • Stripped _shardy suffixes from test names where applicable
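The consolidation pattern described above might look roughly like the following. This is a simplified sketch under stated assumptions: `run_distributed_softmax_case` and `test_softmax` are hypothetical names, and the stand-in body only echoes its inputs instead of building a device mesh.

```python
def run_distributed_softmax_case(mesh_shape, dtype):
    """Hypothetical stand-in for a distributed test body.

    A real test would build a device mesh and compare sharded outputs
    against a single-device reference; this sketch only echoes its inputs.
    """
    return {"mesh_shape": mesh_shape, "dtype": dtype}


# Before (sketch): every case existed twice, and a flag chose the partitioner.
#   def test_softmax(mesh_shape, dtype):
#       run_case(mesh_shape, dtype, use_shardy=False)
#   def test_softmax_shardy(mesh_shape, dtype):
#       run_case(mesh_shape, dtype, use_shardy=True)


# After: a single test with no partitioner flag and no "_shardy" suffix.
def test_softmax(mesh_shape=(2, 2), dtype="bfloat16"):
    return run_distributed_softmax_case(mesh_shape, dtype)


result = test_softmax()
```

Because the partitioner is no longer a test parameter, each case runs once instead of twice.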

Completeness:
The PR addresses the concern raised in the previous review thread - the permutation.py file was updated to remove all infer_sharding_from_operands methods from triton extension primitives. All GSPMD references have been thoroughly removed from the codebase.

Confidence Score: 5/5

  • This PR is safe to merge - it's a clean, systematic removal of deprecated GSPMD code with comprehensive test coverage preserved
  • The refactoring is thorough and consistent across all affected files. The deprecation is complete - all infer_sharding_from_operands implementations have been removed, all GSPMD test paths eliminated, and the code now relies solely on Shardy. No partial migrations or mixed states remain.
  • No files require special attention - all changes follow a consistent pattern

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/cpp_extensions/base.py | Removed abstract method infer_sharding_from_operands and its registration in def_partition(), cleanly deprecating GSPMD support |
| transformer_engine/jax/cpp_extensions/gemm.py | Removed infer_sharding_from_operands implementation from GemmPrimitive |
| transformer_engine/jax/cpp_extensions/attention.py | Removed infer_sharding_from_operands from FusedAttn primitives and cleaned up stale sync comments in shardy_sharding_rule methods |
| transformer_engine/jax/triton_extensions/permutation.py | Removed infer_sharding_from_operands from all permutation primitives (RowIdMap, Permute, Unpermute, MakeChunkSortMap, SortChunksByMap) |
| tests/jax/test_distributed_fused_attn.py | Consolidated GSPMD/Shardy tests by removing use_shardy parameter and duplicate _shardy test variants |
| examples/jax/encoder/test_model_parallel_encoder.py | Removed --enable-shardy CLI argument and all _shardy test variants (BF16, FP8, MXFP8, NVFP4 with/without SP) |

Last reviewed commit: dd46149

Contributor

@greptile-apps greptile-apps bot left a comment

23 files reviewed, 1 comment

@phu0ngng
Collaborator Author

/te-ci JAX L1

Contributor

@greptile-apps greptile-apps bot left a comment

12 files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

17 files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

19 files reviewed, no comments

phu0ngng and others added 11 commits February 26, 2026 16:29
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tests

GSPMD sharding propagation is being deprecated in favour of Shardy,
which is now the default JAX partitioner. This commit removes all
GSPMD-related code paths and tests:

- Drop the infer_sharding_from_operands abstract method from
  BasePrimitive and remove it from def_partition() registration
- Remove all infer_sharding_from_operands implementations across
  cpp_extensions: activation, amax, attention, gemm, normalization,
  quantization, and softmax primitives
- Remove stale "Keep in sync with infer_sharding_from_operands"
  comments from FusedAttn shardy_sharding_rule methods
- Drop all use_shardy=False (GSPMD) distributed test paths and the
  jax.config.update("jax_use_shardy_partitioner", ...) config calls
- Consolidate paired GSPMD/Shardy test functions into single tests
  and strip _shardy suffixes from test names

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment on lines 201 to 206
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(
    infer_sharding_from_operands=cls.infer_sharding_from_operands,
    partition=cls.partition,
    sharding_rule=cls.shardy_sharding_rule,
)
Contributor

Removing infer_sharding_from_operands from def_partition() will affect all primitives that inherit from BasePrimitive. The file transformer_engine/jax/triton_extensions/permutation.py contains multiple primitives (RowIdMapPass1Primitive, RowIdMapPass2Primitive, etc.) that still define infer_sharding_from_operands methods but are not updated in this PR. After this change, those methods will no longer be registered or used, potentially causing different sharding behavior for triton extension primitives.
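One way to catch the kind of leftover this comment describes is a small audit over the class hierarchy. The sketch below is an assumption-laden illustration: `find_leftover_methods` is a hypothetical helper, and the classes are empty stand-ins that only borrow names from this thread, not the real Transformer Engine primitives.

```python
def find_leftover_methods(base, name="infer_sharding_from_operands"):
    """Return sorted names of subclasses of `base` that define `name` themselves."""
    leftovers = []
    seen = set()
    stack = list(base.__subclasses__())
    while stack:
        cls = stack.pop()
        if cls in seen:
            continue
        seen.add(cls)
        # vars(cls) holds only attributes defined on this class, not inherited ones.
        if name in vars(cls):
            leftovers.append(cls.__name__)
        stack.extend(cls.__subclasses__())
    return sorted(leftovers)


class BasePrimitive:
    """Stand-in base class; the hook has already been dropped from it."""


class GemmPrimitive(BasePrimitive):
    """Stand-in for a primitive that was already cleaned up."""

    @staticmethod
    def partition(*args):
        raise NotImplementedError("illustrative only")


class RowIdMapPass1Primitive(BasePrimitive):
    """Stand-in for a primitive that still carries the unused hook."""

    @staticmethod
    def infer_sharding_from_operands(*args):
        raise NotImplementedError("illustrative only")


leftovers = find_leftover_methods(BasePrimitive)
print(leftovers)
```

A check of this shape only sees classes that have been imported, so the extension modules would need to be loaded before running it.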

phu0ngng added 2 commits March 2, 2026 07:48
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng
Collaborator Author

phu0ngng commented Mar 2, 2026

/te-ci JAX L1

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM once CI finishes, thanks for making this change!

This full removal is valid as users will be in one of the following situations, right?

a) The user is on the latest JAX version, in which case having these GSPMD functions around could cause errors since GSPMD has been removed, so removing this logic as we do in this PR is correct
b) The user is on an older version of JAX, in which case they can use GSPMD or Shardy. There have been updates from JAX itself about this transition so users have been aware they need to move to Shardy by March 2026, so in our case removing it should still be okay
c) The user is on a very old version of JAX (e.g. more than 1 year old), where Shardy doesn't work or has bugs. Such users are likely to hit other TE/JAX compatibility issues as well, so they should update to a more recent JAX version
