
optimizer: Prevent constant folding of DynamicQuantizeLinear #2865

Merged
justinchuby merged 2 commits into main from copilot/fix-dequantize-linear-folding-issue
Mar 26, 2026

Conversation

Contributor

Copilot AI commented Mar 25, 2026

The constant folding pass was eliminating DequantizeLinear nodes that operated on constant weight tensors during optimize(), collapsing the quantization structure into a plain Conv and losing quantization semantics in QAT-exported models.

Changes

  • optimizer/_constant_folding.py: Add DynamicQuantizeLinear to DEFAULT_CONSTANT_FOLD_BLACKLIST alongside the existing QuantizeLinear and DequantizeLinear entries; reorder alphabetically for consistency
  • optimizer/_constant_folding_test.py: Add tests verifying QuantizeLinear and DequantizeLinear are not folded when all inputs are constant initializers
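A minimal sketch of how such a fold blacklist typically gates a constant-folding pass. The set name and its entries mirror the PR; the `should_fold` helper and the surrounding pass logic are illustrative, not the actual onnxscript implementation:

```python
# Illustrative sketch of an op-type blacklist guarding constant folding.
# The set name and entries mirror the PR; the fold gate itself is hypothetical.
DEFAULT_CONSTANT_FOLD_BLACKLIST = frozenset({
    "DequantizeLinear",
    "DynamicQuantizeLinear",  # the entry added by this PR
    "QuantizeLinear",
})

def should_fold(op_type: str, all_inputs_constant: bool) -> bool:
    """Fold only when every input is constant AND the op is not blacklisted."""
    if op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST:
        return False  # preserve Q/DQ structure and quantization semantics
    return all_inputs_constant

print(should_fold("DequantizeLinear", True))  # → False: never folded
print(should_fold("Add", True))               # → True: ordinary op, constant inputs
```

Blacklisting by op type is simpler and safer than trying to detect quantization patterns after the fact, since a folded DequantizeLinear is indistinguishable from a plain constant weight.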
Original prompt

This section details the original issue you should resolve

<issue_title>[ONNX] Optimize should not fold DequantizeLinear</issue_title>
<issue_description>### 🐛 Describe the bug

After the QAT model goes through onnx_program.optimize(), quantization nodes are lost. In the figure, the left side shows the normal export and the right side shows the abnormal exported graph.

[Image: comparison of the normal (left) and abnormal (right) exported graphs]

This bug occurred in torch/onnx/_internal/exporter/_onnx_program.py:

def optimize(self) -> None:
    self.model = onnxscript_apis.optimize(self.model)

which internally calls the optimize_ir function in onnxscript/optimizer/_optimizer.py.
The default value of input_size_limit is 512; nodes whose inputs are smaller than this limit are folded away.

def optimize_ir(
    model: ir.Model,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> None:
    passes = [
        ir.passes.PassManager(
            [
                _constant_folding.FoldConstantsPass(
                    shape_inference=onnx_shape_inference,
                    input_size_limit=input_size_limit,
                    output_size_limit=output_size_limit,
                ),
                rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
                common_passes.RemoveUnusedNodesPass(),
                common_passes.RemoveUnusedFunctionsPass(),
                common_passes.RemoveUnusedOpsetsPass(),
            ],
            steps=num_iterations,
            early_stop=stop_if_no_change,
        ),
    ......
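The input_size_limit gate the reporter describes can be sketched roughly as follows. The constant name matches the default referenced above; the `within_fold_limits` helper and the example shapes are hypothetical, not the real onnxscript check:

```python
import numpy as np

DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 512  # default cited in the issue

def within_fold_limits(
    inputs: list,
    input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
) -> bool:
    """Illustrative: a node is eligible for folding only if every
    constant input has at most `input_size_limit` elements."""
    return all(arr.size <= input_size_limit for arr in inputs)

small_weight = np.zeros((4, 4, 3, 3), dtype=np.float32)   # 144 elements
big_weight = np.zeros((64, 64, 3, 3), dtype=np.float32)   # 36864 elements
print(within_fold_limits([small_weight]))  # → True  (eligible for folding)
print(within_fold_limits([big_weight]))    # → False (too large to fold)
```

This is why the repro's tiny 4-channel Conv weights (144 elements, well under 512) fall below the limit and get folded, taking the DequantizeLinear with them unless the op type is blacklisted.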

⭐ Please fix the optimization behavior in torch/onnx/_internal/exporter/_onnx_program.py; otherwise my only workaround is to install onnxscript from source.

The smallest reproducible example:

import copy
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from onnxslim import slim
import onnx


class ConvBnReluModel(nn.Module):
    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


def get_batch_norm_node_args(gm):
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.batch_norm.default:
            return tuple(node.args)
    raise RuntimeError("No aten.batch_norm.default node found")


torch.manual_seed(0)
device = 'cuda' 

model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()
print("before prepare:", get_batch_norm_node_args(exported))

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
prepared.to(device)
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
prepared.eval()

#---export quantized model to onnx---
qat_onnx_sp = './quant.onnx'
quantized_model = convert_pt2e(prepared)
print('convert_pt2e done!')
onnx_program = torch.onnx.export(quantized_model, inputs, dynamo=True, opset_version=21)
""" bug """
onnx_program.optimize()
onnx_program.save(qat_onnx_sp)
print(f'export qat model to [{qat_onnx_sp}] done!')

model_simp = slim(onnx_program.model_proto)
sim_path = qat_onnx_sp.replace('.onnx', '_slim.onnx')
onnx.save(model_simp, sim_path)
print(f"save onnx model to [{sim_path}] Successfully!")

Versions

Versions of relevant libraries:
[pip3] executorch==0.5.0
[pip3] numpy==1.23.5
[pip3] nvidia-cublas-cu11==11.11.3.6
[pip3] nvidia-cuda-cupti-cu11==11.8.87
[pip3] nvidia-cuda-nvrtc-cu11==11.8.89
[pip3] nvidia-cuda-runtime-cu11==11.8.89
[pip3] nvidia-cudnn-cu11==9.1.0.70
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.3.0.86
[pip3] nvidia-cusolver-cu11==11.4.1.48
[pip3] nvidia-cusparse-cu11==11.7.5.86
[pip3] nvidia-nccl-cu11==2.21.5
[pip3] nvidia-nvtx-cu11==11.8.86
[pip3] onnx==1.17.0
[pip3] onnx_graphsurgeon==0.5.8
[pip3] onnx-ir==0.1.12
[pip3] onnx-simplifier==0.4.36
[pip3] onnxruntime==1.21.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.4.0
[pip3] onnxslim==0.1.48
[pip3] torch==2.6.0+cu118
[pip3] torchao==0.14.1
[pip3] torchaudio==2.6.0+cu118
[pip3] torchvision==0.21.0+cu118
[pip3] ...



Copilot AI changed the title from "[WIP] Fix optimization folding issue for DequantizeLinear in ONNX" to "optimizer: Prevent constant folding of quantization ops (QuantizeLinear, DequantizeLinear, DynamicQuantizeLinear)" Mar 25, 2026
Copilot AI requested a review from justinchuby March 25, 2026 23:48
@justinchuby
Collaborator

@copilot strange: DequantizeLinear was in the list, why was it still folded in the original issue?

Contributor Author

Copilot AI commented Mar 26, 2026

@copilot strange: DequantizeLinear was in the list, why was it still folded in the original issue?

The DequantizeLinear and QuantizeLinear entries were added to the blacklist in commit bbe9c2bc (PR #2713), but that commit was not included in the v0.4.0 release (confirmed via git merge-base --is-ancestor). The issue reporter was using onnxscript==0.4.0, which is why they saw the folding behavior.

My change adds DynamicQuantizeLinear to complete the list of quantization ops that should be protected from folding in the current (post-bbe9c2bc) codebase.
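The ancestry check mentioned above can be reproduced with `git merge-base --is-ancestor`. Here is a self-contained demo in a throwaway repo; the tag `v0.4.0` and the commits are synthetic stand-ins for the real onnxscript release and commit bbe9c2bc:

```shell
# Demonstrate checking whether a fix commit shipped in a release tag.
# Throwaway repo: the tag and commits stand in for the real ones.
set -e
repo=$(mktemp -d)
cd "$repo"
git init -q
git config user.email demo@example.com
git config user.name demo
echo base > file.txt
git add file.txt
git commit -qm "base"
git tag v0.4.0                        # release cut here
echo fix >> file.txt
git commit -qam "blacklist Q/DQ ops"  # the fix lands AFTER the release
fix=$(git rev-parse HEAD)
if git merge-base --is-ancestor "$fix" v0.4.0; then
  echo "fix IS in v0.4.0"
else
  echo "fix is NOT in v0.4.0"         # matches the reporter's situation
fi
```

`git merge-base --is-ancestor A B` exits 0 when A is reachable from B, so it answers "did this commit make it into that release?" without any log scraping.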

@justinchuby justinchuby marked this pull request as ready for review March 26, 2026 18:19
@justinchuby justinchuby changed the title from "optimizer: Prevent constant folding of quantization ops (QuantizeLinear, DequantizeLinear, DynamicQuantizeLinear)" to "optimizer: Prevent constant folding of DynamicQuantizeLinear" Mar 26, 2026
@justinchuby justinchuby enabled auto-merge (squash) March 26, 2026 18:24
@codecov

codecov bot commented Mar 26, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.87%. Comparing base (19e5284) to head (a78b1cc).
⚠️ Report is 1 commit behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2865      +/-   ##
==========================================
+ Coverage   71.86%   71.87%   +0.01%     
==========================================
  Files         239      239              
  Lines       29138    29150      +12     
  Branches     2875     2875              
==========================================
+ Hits        20941    20953      +12     
  Misses       7219     7219              
  Partials      978      978              


@justinchuby justinchuby merged commit 847801c into main Mar 26, 2026
32 of 35 checks passed
@justinchuby justinchuby deleted the copilot/fix-dequantize-linear-folding-issue branch March 26, 2026 18:41


Development

Successfully merging this pull request may close these issues.

[ONNX] Optimize should not fold DequantizeLinear
