[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests#2702
phu0ngng wants to merge 14 commits into NVIDIA:main
Greptile Summary
This PR successfully deprecates GSPMD sharding propagation in favor of Shardy, which is now the default JAX partitioner. The changes are comprehensive and systematic:
Core Changes:
Test Consolidation:
Completeness:
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: dd46149
/te-ci JAX L1
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tests
GSPMD sharding propagation is being deprecated in favour of Shardy,
which is now the default JAX partitioner. This commit removes all
GSPMD-related code paths and tests:
- Drop the infer_sharding_from_operands abstract method from
BasePrimitive and remove it from def_partition() registration
- Remove all infer_sharding_from_operands implementations across
cpp_extensions: activation, amax, attention, gemm, normalization,
quantization, and softmax primitives
- Remove stale "Keep in sync with infer_sharding_from_operands"
comments from FusedAttn shardy_sharding_rule methods
- Drop all use_shardy=False (GSPMD) distributed test paths and the
jax.config.update("jax_use_shardy_partitioner", ...) config calls
- Consolidate paired GSPMD/Shardy test functions into single tests
and strip _shardy suffixes from test names
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
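The registration change described in this commit message can be sketched in plain Python. This is only an illustration under stated assumptions: `FakeCustomPartitioning`, `BasePrimitiveSketch`, `register_primitive`, and `ScalePrimitive` are hypothetical stand-ins, not Transformer Engine or JAX code, and the fake wrapper merely records the keyword arguments that a `def_partition()` call would receive.

```python
from abc import ABC, abstractmethod


class FakeCustomPartitioning:
    """Hypothetical stand-in for JAX's custom_partitioning wrapper; it only records kwargs."""

    def __init__(self, fn):
        self.fn = fn
        self.registered = {}

    def def_partition(self, **kwargs):
        self.registered = dict(kwargs)


class BasePrimitiveSketch(ABC):
    """Hypothetical trimmed-down base class that keeps only the Shardy-era hooks."""

    @staticmethod
    @abstractmethod
    def impl(*args): ...

    @staticmethod
    @abstractmethod
    def partition(mesh, arg_infos, result_infos): ...

    @staticmethod
    @abstractmethod
    def shardy_sharding_rule(*args): ...


def register_primitive(cls):
    """Register a primitive without the infer_sharding_from_operands callback."""
    wrapper = FakeCustomPartitioning(cls.impl)
    wrapper.def_partition(
        partition=cls.partition,
        sharding_rule=cls.shardy_sharding_rule,
    )
    return wrapper


class ScalePrimitive(BasePrimitiveSketch):
    """Toy primitive: subclasses no longer implement infer_sharding_from_operands."""

    @staticmethod
    def impl(x):
        return [2 * v for v in x]

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        return mesh, ScalePrimitive.impl, result_infos, arg_infos

    @staticmethod
    def shardy_sharding_rule(*args):
        return "i -> i"  # placeholder rule string for illustration only


registered = register_primitive(ScalePrimitive).registered
print(sorted(registered))
```

With the abstract method gone from the base class, a subclass only has to supply the partition function and the Shardy sharding rule, which is the contract the commit message describes.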
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
  batching.primitive_batchers[outer_p] = cls.batcher
  outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
  outer_p_lower.def_partition(
-     infer_sharding_from_operands=cls.infer_sharding_from_operands,
      partition=cls.partition,
      sharding_rule=cls.shardy_sharding_rule,
  )
Removing infer_sharding_from_operands from def_partition() will affect all primitives that inherit from BasePrimitive. The file transformer_engine/jax/triton_extensions/permutation.py contains multiple primitives (RowIdMapPass1Primitive, RowIdMapPass2Primitive, etc.) that still define infer_sharding_from_operands methods but are not updated in this PR. After this change, those methods will no longer be registered or used, potentially causing different sharding behavior for triton extension primitives.
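One way to catch leftovers like the ones this comment describes is a small introspection check that lists every subclass still defining the removed hook. The sketch below is an assumption-laden illustration: `find_stale_overrides`, `DemoBase`, and the `Demo*` classes are hypothetical names, and the real primitives would first have to be imported so that they appear in `__subclasses__()`.

```python
def find_stale_overrides(base, method_name):
    """Return sorted names of (transitive) subclasses whose own body defines method_name."""
    stale = []
    seen = set()
    pending = list(base.__subclasses__())
    while pending:
        cls = pending.pop()
        if cls in seen:
            continue
        seen.add(cls)
        # vars(cls) holds only attributes defined directly on cls, not inherited ones.
        if method_name in vars(cls):
            stale.append(cls.__name__)
        pending.extend(cls.__subclasses__())
    return sorted(stale)


# Hypothetical classes standing in for real primitives.
class DemoBase:
    pass


class DemoUpdatedPrimitive(DemoBase):
    @staticmethod
    def partition():
        pass


class DemoLegacyPrimitive(DemoBase):
    @staticmethod
    def infer_sharding_from_operands():
        pass

    @staticmethod
    def partition():
        pass


class DemoLegacyChild(DemoLegacyPrimitive):
    pass  # inherits the stale method but does not define it itself


stale = find_stale_overrides(DemoBase, "infer_sharding_from_operands")
print(stale)
```

A check of this kind could run as a unit test so that any primitive still carrying the unused method is flagged instead of silently becoming dead code.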
/te-ci JAX L1
jberchtold-nvidia left a comment
LGTM once CI finishes, thanks for making this change!
This full removal is valid because every user falls into one of the following situations, right?
a) The user is on the latest JAX version. Keeping these GSPMD functions around could cause errors there since GSPMD has been removed, so removing this logic as we do in this PR is correct.
b) The user is on an older version of JAX and can use either GSPMD or Shardy. JAX itself has announced this transition, so users have known they need to move to Shardy by March 2026; removing GSPMD support on our side should still be okay.
c) The user is on a very old version of JAX (e.g. >1 year old) where Shardy doesn't work or has bugs. They are likely to hit other TE/JAX compatibility issues as well, so they should update to a more recent JAX version.