fix: add setup_context for torch.func compatibility#7916

Open
roycho96 wants to merge 14 commits into deepspeedai:master from roycho96:fix/support-func-torch

Conversation

@roycho96

LinearFunctionForZeroStage3 uses the legacy forward(ctx, ...) pattern which is incompatible with torch.func transforms (torch.func.grad, torch.func.grad_and_value, vmap, etc.):

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.

This affects any library that uses torch.func internally on a ZeRO-3 model.
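For illustration, here is a minimal self-contained sketch (a hypothetical Square function, not DeepSpeed code) of the modern forward + setup_context shape that torch.func requires; the same Function written in the legacy forward(ctx, ...) style raises the RuntimeError quoted above under torch.func.grad:

```python
import torch

# Hypothetical toy Function using the modern (torch.func-compatible) pattern.
class Square(torch.autograd.Function):

    @staticmethod
    def forward(x):  # modern pattern: no ctx argument here
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out

# Works under torch.func; the legacy forward(ctx, x) form would raise here.
g = torch.func.grad(lambda x: Square.apply(x).sum())(torch.tensor(3.0))
print(g)  # tensor(6.)
```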

Fix

Fixes #7913

Note

As pointed out by @zhangj1an in #7913, PostBackwardFunctionModule and PreBackwardFunctionForModule in parameter_offload.py have the same issue. Those will be addressed in a follow-up commit within this PR.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: eed37042bc

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +61 to +64
@staticmethod
def setup_context(ctx, inputs, output):
    input, weight, bias = inputs
    ctx.save_for_backward(input, weight, bias)


P1: Preserve AMP context when using separate setup_context

On the _SUPPORTS_SETUP_CONTEXT branch, forward() remains wrapped with @custom_fwd and backward() with @custom_bwd, but ctx is now populated via an undecorated setup_context(). PyTorch AMP currently depends on custom_fwd seeding autocast state on ctx; with the split form, autocast-enabled backward can fail with AttributeError: ... _fwd_used_autocast, which breaks the mixed-precision path already covered by tests/unit/runtime/test_autocast.py.

Useful? React with 👍 / 👎.

@zhangj1an

Following up on Codex's comment: for PyTorch >= 2.0, I suggest removing the @autocast_custom_fwd line from forward. forward would then look like this:

@staticmethod
# bias is an optional argument
def forward(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

The reason: per line 35, @autocast_custom_fwd is an alias of custom_fwd in torch/amp/autocast_mode.py, whose nested decorate_fwd method looks like this:

@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
    args[0]._dtype = torch.get_autocast_dtype(device_type)
    if cast_inputs is None:
        args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
        return fwd(*args, **kwargs)
    else:
        ...
# (the enclosing custom_fwd then does: return decorate_fwd)

Here args[0] is treated as ctx (e.g. args[0]._dtype, args[0]._fwd_used_autocast), and those attributes are no longer set once we remove the deprecated ctx argument from forward. We can re-implement these few lines in setup_context as below:

@staticmethod
def setup_context(ctx, inputs, output):
    device_type = get_accelerator().device_name()
    ctx._dtype = torch.get_autocast_dtype(device_type)
    ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
    input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
    ctx.save_for_backward(input, weight, bias)
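Putting the suggested pieces together, here is a runnable sketch of the split form under torch.func.grad (toy class name LinearFn and a simplified backward are assumptions; the AMP state capture discussed above is omitted for portability, and the real DeepSpeed backward differs):

```python
import torch

# Toy modern-form linear Function with optional bias (illustrative only).
class LinearFn(torch.autograd.Function):

    @staticmethod
    def forward(input, weight, bias=None):  # no ctx in the modern pattern
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

x = torch.randn(3, 4)
w = torch.randn(5, 4)
b = torch.randn(5)

# torch.func transforms now work over the Function.
g = torch.func.grad(lambda x: LinearFn.apply(x, w, b).sum())(x)
```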

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
… unpack error

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the fix/support-func-torch branch from 252aea1 to 39b1755 on March 21, 2026 10:28

@zhangj1an zhangj1an left a comment


Thanks for the work! I implemented the same fix, so it looks good to me. To reduce reviewers' effort, I have two minor comments; these should leave linear.py with only 17 insertions (+) and 7 deletions (-).



# PyTorch >= 2.0 supports setup_context, which is required for
# torch.func transforms (vmap, grad, jvp, jacrev, etc.)
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context')


I tried to implement a version check too, but could not find a way to avoid duplicated code. What do you think about adding an explicit error like the one below,

def _require_pt2_zero3_linear():
    if int(torch.__version__.split(".")[0]) < 2:
        raise RuntimeError("LinearFunctionForZeroStage3 requires PyTorch >= 2.0")

and use it in LinearModuleForZeroStage3's forward method?

For PyTorch 1.9, we would simply deprecate support by adding a line to the requirements section of README.md:

For ZeRO Stage 3,  `PyTorch >= 2.0` is required, as parameter-offload hooks rely on `torch.autograd.Function.setup_context`.

Author

Good catch! You're right.

Also, I agree the version-branch duplication isn't ideal. However, I think dropping PyTorch < 2.0 support is a policy decision that should come from the maintainers.

@zhangj1an

zhangj1an commented Mar 21, 2026

Hey @roycho96, I used this script to reproduce your issue #7913:
repro_zero3_functorch.py

After implementing the fixes in parameter_offload.py, I am still not able to run this file normally; it returns a segmentation fault. I will need more time to investigate what is happening.

I am not authorised to push directly to your branch. In case I'm blocking your PR, here is my fix for parameter_offload.py:
parameter_offload.py
Feel free to update it on your end!

@roycho96
Author

> Thanks for the work! I implemented the same fix, so it looks good to me. To reduce reviewers' effort, I have two minor comments; these should leave linear.py with only 17 insertions (+) and 7 deletions (-).

Hi @zhangj1an, I've sent you a collaborator invite to my fork. Feel free to push your fix directly to the branch. Thanks for the suggestion!

… setup_context

Co-authored-by: zhangj1an <jianmusings@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@zhangj1an zhangj1an force-pushed the fix/support-func-torch branch from 444122c to 6df37af on March 22, 2026 08:45
…afe linear

Avoid asymmetric custom_bwd without custom_fwd on the setup_context forward path;
mirror forward AMP in backward via torch.amp.autocast.

Signed-off-by: Zhang <jianmusings@gmail.com>
PyTorch versions that expose autograd.Function.setup_context need the
modern forward + setup_context shape for torch.func / functorch.

Signed-off-by: Zhang <jianmusings@gmail.com>
@zhangj1an zhangj1an force-pushed the fix/support-func-torch branch from 0a66444 to 5e83d05 on March 22, 2026 09:34
@zhangj1an

zhangj1an commented Mar 22, 2026

Hey @roycho96 , thanks for the branch invite! I finished the fix required from my side.

summary of fix
  • 5e83d05: implemented the fix in parameter_offload.py as discussed.
  • [minor] 6df37af: signed your commit 444122c to pass DCO check.
  • [minor] c0b9694: in LinearFunctionForZeroStage3, removed @autocast_custom_bwd to be consistent with the removal of @autocast_custom_fwd in 444122c

However, I now see a different error. When ZeRO-3 is enabled, a model that contains F.linear goes through ZeRO-3's forward in LinearFunctionForZeroStage3. When LinearFunctionForZeroStage3.apply is called, the lines involving weight.t(), torch.addmm, and matmul trigger a segmentation fault under functorch (GDB shows the crash at at::functorch::TensorWrapper::is_alive()).

Due to this error, I cannot yet produce a verification script that runs on this branch and fails on the main branch. I'm not sure how to solve it. Do you have any ideas? 🤔 Thanks in advance!

Update: found a verification script. My previous script did not go through torchrun properly and used dummy arguments, which caused the segmentation fault.

zhangj1an and others added 9 commits March 24, 2026 22:21
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
@zhangj1an

zhangj1an commented Mar 25, 2026

Hey @roycho96, I think this PR is ready for review! Below is my attempt for the updated PR description. Feel free to make any changes.

Summary

Fixes #7913

ZeRO-3’s custom linear autograd.Function used the legacy forward(ctx, …) + ctx.save_for_backward pattern. On recent PyTorch (e.g. 2.8+), any torch.func transform (grad, grad_and_value, vmap, …) over that path raises:

RuntimeError: In order to use an autograd.Function with functorch transforms … it must override the setup_context staticmethod.

This breaks real stacks that combine ZeRO-3 with functorch-based code (e.g. some Liger-Kernel / Axolotl KD paths). More details in the issue above.

Fix

  • deepspeed/runtime/zero/linear.py : For LinearFunctionForZeroStage3, adopt the autograd.Function pattern PyTorch expects under functorch: non-context forward + setup_context (saved tensors moved into setup_context). backward is unchanged.
  • deepspeed/runtime/zero/parameter_offload.py : Same change for ZeRO-3 offload autograd.Function definitions that used ctx inside the old-style forward.
  • Regression test: tests/unit/v1/zero/test_zero_functorch_linear.py (skips if torch.autograd.Function.setup_context is missing, e.g. PyTorch < 2.0).

When setup_context is not available, the legacy forward(ctx, …) path is retained so older PyTorch builds are unaffected.
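The dual-path registration can be sketched as follows (a condensed, hypothetical two-argument linear with no bias; the real LinearFunctionForZeroStage3 also handles bias and AMP):

```python
import torch

# Feature detection: PyTorch >= 2.0 exposes autograd.Function.setup_context.
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")

if _SUPPORTS_SETUP_CONTEXT:
    class _Linear(torch.autograd.Function):
        # Modern shape: ctx-free forward plus setup_context (torch.func-safe).
        @staticmethod
        def forward(input, weight):
            return input.matmul(weight.t())

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.save_for_backward(*inputs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output.matmul(weight), grad_output.t().matmul(input)
else:
    class _Linear(torch.autograd.Function):
        # Legacy shape for older PyTorch builds; behaviour is identical.
        @staticmethod
        def forward(ctx, input, weight):
            ctx.save_for_backward(input, weight)
            return input.matmul(weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output.matmul(weight), grad_output.t().matmul(input)

x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
_Linear.apply(x, w).sum().backward()
```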

How to verify

The bug can be reproduced on main branch (DeepSpeed 0.18.8+…) with PyTorch 2.8.

Two temporary scripts are added to help reviewers compare this branch vs master.

| File | Role |
| --- | --- |
| scripts/repro_pr7916.py | Minimal repro: ZeRO-3 via deepspeed.initialize, then torch.func.grad_and_value is called via zero3_linear_wrap. |
| scripts/setup_pr7916.sh | Creates a .venvs/pr7916 with PyTorch 2.8 + cu128, runs the repro on this branch and then on master. |

How to run

./scripts/setup_pr7916.sh
# Optional: wipe venv and reinstall from scratch
./scripts/setup_pr7916.sh --force-install

The script above will print results from both branches. Alternatively, if the venv is already set up, you can run

source .venvs/pr7916/bin/activate
torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py

Expected output

  • This branch: completes; rank 0 prints
    repro: grad_and_value over zero3_linear_wrap (LinearFunctionForZeroStage3) OK.
  • master (unfixed): fails with the functorch error, RuntimeError: In order to use an autograd.Function with functorch transforms … it must override the setup_context staticmethod.

Test plan

  • pytest tests/unit/v1/zero/test_zero_functorch_linear.py

Cleanup

  • Remove scripts/repro_pr7916.py and scripts/setup_pr7916.sh after PR approval.

Example environment where the bug was reproduced

| Item | Value |
| --- | --- |
| OS | Ubuntu 22.04 |
| GPU | NVIDIA H100 80GB PCIe |
| Python | 3.11 |
| PyTorch | 2.8.0+cu128 |
| PyTorch CUDA (wheel) | 12.8 |
| DeepSpeed (original report) | 0.16.4 (wheel) |

@roycho96 roycho96 marked this pull request as ready for review March 25, 2026 14:35
@roycho96 roycho96 requested a review from loadams as a code owner March 25, 2026 14:35

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 60d20da79f


Comment on lines +78 to +81
if getattr(ctx, "_fwd_used_autocast", False):
    with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype):
        return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)


P2: Restore backward autocast state for both true and false cases

Fresh evidence in this commit: the new setup_context path only enters an autocast context when ctx._fwd_used_autocast is true, and otherwise runs backward without forcing autocast off. @custom_bwd semantics are to run backward with enabled=ctx._fwd_used_autocast, so if backward executes inside an outer autocast-enabled region while forward ran without autocast, gradients are now computed under the wrong autocast state compared with the legacy path.

Useful? React with 👍 / 👎.


Development

Successfully merging this pull request may close these issues.

[BUG] LinearFunctionForZeroStage3 crashes with torch.func transforms (missing setup_context)

2 participants