fix: add setup_context for torch.func compatibility #7916
roycho96 wants to merge 14 commits into deepspeedai:master from
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eed37042bc
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
@staticmethod
def setup_context(ctx, inputs, output):
    input, weight, bias = inputs
    ctx.save_for_backward(input, weight, bias)
```
**Preserve AMP context when using a separate `setup_context`**

On the `_SUPPORTS_SETUP_CONTEXT` branch, `forward()` remains wrapped with `@custom_fwd` and `backward()` with `@custom_bwd`, but `ctx` is now populated via an undecorated `setup_context()`. PyTorch AMP currently depends on `custom_fwd` seeding autocast state on `ctx`; with the split form, autocast-enabled backward can fail with `AttributeError: ... _fwd_used_autocast`, which breaks the mixed-precision path already covered by `tests/unit/runtime/test_autocast.py`.
Following up on Codex's comment: I suggest that for PyTorch >= 2.0, this line, `@autocast_custom_fwd`, should be removed from `forward`. So `forward` will look like this:
```python
@staticmethod
# bias is an optional argument
def forward(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret
```
The reason is that, per line 35, `@autocast_custom_fwd` is an alias of `custom_fwd` in `torch/amp/autocast_mode.py`. Its nested `decorate_fwd` method looks like this:
```python
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
    args[0]._dtype = torch.get_autocast_dtype(device_type)
    if cast_inputs is None:
        args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
        return fwd(*args, **kwargs)
    else:
        ...
return decorate_fwd
```
Here `args[0]` is actually treated as `ctx` (e.g. `args[0]._dtype`, `args[0]._fwd_used_autocast`), which is no longer available after we remove the deprecated `ctx` from the `forward` method. We can re-implement these few lines in `setup_context` like below:
```python
@staticmethod
def setup_context(ctx, inputs, output):
    device_type = get_accelerator().device_name()
    ctx._dtype = torch.get_autocast_dtype(device_type)
    ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
    input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
    ctx.save_for_backward(input, weight, bias)
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
… unpack error Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
```python
# PyTorch >= 2.0 supports setup_context, which is required for
# torch.func transforms (vmap, grad, jvp, jacrev, etc.)
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context')
```
I tried to implement a version check too, but cannot find a way to avoid duplicated code. What do you think about adding an error check like below,
```python
def _require_pt2_zero3_linear():
    if int(torch.__version__.split(".")[0]) < 2:
        raise RuntimeError("LinearFunctionForZeroStage3 requires PyTorch >= 2.0")
```
and using it in `LinearModuleForZeroStage3`'s `forward` method?

For PyTorch 1.9 we would just deprecate support by adding a new line to the requirements section of README.md:

> For ZeRO Stage 3, `PyTorch >= 2.0` is required, as parameter-offload hooks rely on `torch.autograd.Function.setup_context`.
Good catch! You're right.

Also, I agree the version-branch duplication isn't ideal. However, I think dropping PyTorch < 2.0 support is a policy decision that should come from the maintainers.
Hey @roycho96, I used this script to reproduce your issue #7913. After I implemented the fixes in … I am not authorised to push directly into your branch. In case I blocked your PR, here is my fix for …
Hi @zhangj1an, I've sent you a collaborator invite to my fork. Feel free to push your fix directly to the branch. Thanks for the suggestion!
… setup_context Co-authored-by: zhangj1an <jianmusings@gmail.com> Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
…afe linear Avoid asymmetric custom_bwd without custom_fwd on the setup_context forward path; mirror forward AMP in backward via torch.amp.autocast. Signed-off-by: Zhang <jianmusings@gmail.com>
PyTorch versions that expose autograd.Function.setup_context need the modern forward + setup_context shape for torch.func / functorch. Signed-off-by: Zhang <jianmusings@gmail.com>
Hey @roycho96, thanks for the branch invite! I finished the fix required from my side. Summary of fix: …

Update: found a verification script. My previous script did not go through …
Signed-off-by: Zhang <jianmusings@gmail.com>
Hey @roycho96, I think this PR is ready for review! Below is my attempt at the updated PR description. Feel free to make any changes.

**Summary**

Fixes #7913. ZeRO-3’s custom linear … This breaks real stacks that combine ZeRO-3 with functorch-based code (e.g. some Liger-Kernel / Axolotl KD paths). More details in the issue above.

**Fix**

When …

**How to verify**

The bug can be reproduced on … Two temporary scripts are added to help reviewers compare this branch vs …

**How to run**

```shell
./scripts/setup_pr7916.sh
# Optional: wipe venv and reinstall from scratch
./scripts/setup_pr7916.sh --force-install
```

The above script will print the result from both branches. Alternatively, if the venv is already set up, you can run:

```shell
source .venvs/pr7916/bin/activate
torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py
```

**Expected output**

**Test plan**

**Cleanup**

**Example environment where the bug was reproduced**
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 60d20da79f
```python
if getattr(ctx, "_fwd_used_autocast", False):
    with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype):
        return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
```
**Restore backward autocast state for both true and false cases**

Fresh evidence in this commit: the new `setup_context` path only enters an autocast context when `ctx._fwd_used_autocast` is true, and otherwise runs backward without forcing autocast off. `@custom_bwd` semantics are to run backward with `enabled=ctx._fwd_used_autocast`, so if backward executes inside an outer autocast-enabled region while forward ran without autocast, gradients are now computed under the wrong autocast state compared with the legacy path.
`LinearFunctionForZeroStage3` uses the legacy `forward(ctx, ...)` pattern, which is incompatible with `torch.func` transforms (`torch.func.grad`, `torch.func.grad_and_value`, `vmap`, etc.). This affects any library that uses `torch.func` internally on a ZeRO-3 model.

**Fix**

Fixes #7913

**Note**

As pointed out by @zhangj1an in #7913, `PostBackwardFunctionModule` and `PreBackwardFunctionForModule` in `parameter_offload.py` have the same issue. Those will be addressed in a follow-up commit within this PR.
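To make the incompatibility concrete, here is a hypothetical minimal repro (toy Functions for illustration, not DeepSpeed's): the legacy `forward(ctx, ...)` shape fails under `torch.func.grad`, while the split `forward` + `setup_context` shape works.

```python
import torch

class LegacyMul(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, y):          # legacy: ctx baked into forward
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        return grad_out * y, grad_out * x

class ModernMul(torch.autograd.Function):

    @staticmethod
    def forward(x, y):               # modern: no ctx in forward
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        return grad_out * y, grad_out * x

a, b = torch.tensor(3.0), torch.tensor(4.0)

try:
    torch.func.grad(lambda x, y: LegacyMul.apply(x, y))(a, b)
except RuntimeError as e:
    # torch.func refuses autograd.Functions without setup_context
    print("legacy form fails:", type(e).__name__)

g = torch.func.grad(lambda x, y: ModernMul.apply(x, y))(a, b)
print(g)  # tensor(4.)
```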