[OpenVINO][Examples] Add Quantization for the OpenVINO Stable Diffusion Example#17807

Open

anzr299 wants to merge 20 commits into pytorch:main from anzr299:an/openvino/quantize_lcm_model

Conversation

@anzr299
Contributor

@anzr299 anzr299 commented Mar 3, 2026

Summary

Extend the stable diffusion example for OpenVINO backend with quantization support.

@pytorch-bot

pytorch-bot bot commented Mar 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17807

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 9 Awaiting Approval

As of commit d6db584 with merge base 40200e6:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 3, 2026
@github-actions

github-actions bot commented Mar 3, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

@daniil-lyakhov daniil-lyakhov left a comment

In general:

I think the "maybe" logic is not worth it here; it would be nicer to have separate quantize_unet and compress_model functions in each export function.

As it stands, the diamond structure makes the export look more complicated than it really is.
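The split suggested above can be sketched as follows. This is a hedged illustration, not the PR's actual code: the names (SDComponent, quantize_unet, compress_model, export_component) and the placeholder return values are all hypothetical, standing in for the real PTQ and weights-only compression calls.

```python
from enum import Enum

class SDComponent(Enum):
    UNET = "unet"
    TEXT_ENCODER = "text_encoder"
    VAE_DECODER = "vae_decoder"

def quantize_unet(model, calibration_dataset):
    # Placeholder for full INT8 PTQ (activations + weights) on the UNet.
    return ("ptq", model)

def compress_model(model):
    # Placeholder for weights-only INT8 compression of the other parts.
    return ("weights_only", model)

def export_component(component, model, calibration_dataset=None):
    # Each export path calls exactly one helper, so no shared
    # "maybe quantize" branch threads through every export function.
    if component is SDComponent.UNET:
        return quantize_unet(model, calibration_dataset)
    return compress_model(model)
```

With this shape, each component's export reads linearly and the quantize-vs-compress decision lives in one obvious place.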

from executorch.exir.backend.backend_details import CompileSpec
from torch.export import export
from torchao.quantization.pt2e.quantizer.quantizer import Quantizer
from tqdm import tqdm # type: ignore[import-untyped]
Contributor

Why is the type ignore here?

Contributor Author

It was suggested by lintrunner.

Comment on lines +242 to +248
# Configure OpenVINO compilation
compile_spec = [CompileSpec("device", device.encode())]
partitioner = OpenvinoPartitioner(compile_spec)

# Lower to edge dialect and apply OpenVINO backend
edge_manager = to_edge_transform_and_lower(
exported_program, partitioner=[partitioner]
Contributor

Duplicate?

Contributor Author

Ah yes. Great catch!

Contributor Author

Done

Comment on lines +175 to +176
if not is_quantization_enabled:
return model
Contributor

Please keep only the code that could raise an error inside the try block.

Contributor Author

Done
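The advice above is the usual try/finally hygiene rule: only the statement that can actually fail goes inside the try, while the finally clause restores state unconditionally. A minimal sketch, with toy stand-ins (set_dtype, collect_calibration_data) for the example's pipeline helpers:

```python
def set_dtype(state, dtype):
    # Stand-in for _set_pipeline_dtype: mutates the pipeline's dtype.
    state["dtype"] = dtype

def collect_calibration_data(state):
    # Stand-in for calibration collection; fails unless we are in FP32.
    if state["dtype"] != "fp32":
        raise RuntimeError("calibration requires fp32")
    return ["sample"]

def calibrate(state, original_dtype):
    set_dtype(state, "fp32")
    try:
        # Only the call that can fail lives inside the try.
        dataset = collect_calibration_data(state)
    finally:
        # The dtype is restored whether or not collection succeeded.
        set_dtype(state, original_dtype)
    return dataset
```

Keeping the try narrow makes it obvious which failure the finally clause is guarding against.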

Comment on lines +180 to +194
# Quantize activations for the Unet Model. Other models are weights-only quantized.
pipeline = self.model_loader.pipeline
try:
# We need the models in FP32 to run inference for calibration data collection
self._set_pipeline_dtype(pipeline, torch.float32)
calibration_dataset = self.get_unet_calibration_dataset(pipeline)
finally:
self._set_pipeline_dtype(pipeline, self.model_loader.dtype)

quantized_model = quantize_model(
model,
mode=QuantizationMode.INT8_TRANSFORMER,
calibration_dataset=calibration_dataset,
smooth_quant=True,
)
Contributor

This if body is big enough that it is worth splitting the function in two, e.g. quantize and compress.

Contributor Author

I have removed the try-finally. Now it is just quantize and compress

Contributor

What about separate functions quantize and compress?


def forward(self, *args, **kwargs):
"""
obtain and pass each input individually to ensure the order is maintained
Contributor

Suggested change
obtain and pass each input individually to ensure the order is maintained
Obtain and pass each input individually to ensure the order is maintained

Contributor Author

Done
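The forward-wrapper pattern being reviewed here can be sketched without torch. CalibrationWrapper below is a hypothetical, simplified stand-in for the example's UNetWrapper: it records each call's positional inputs so they can later be replayed as calibration samples, while forwarding them unchanged.

```python
class CalibrationWrapper:
    """Toy stand-in for a calibration wrapper (name illustrative):
    captures each call's inputs and delegates to the wrapped callable."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.captured = []

    def __call__(self, *args, **kwargs):
        # Obtain and pass each input individually so the order the
        # wrapped callable expects is maintained.
        self.captured.append(tuple(args))
        return self.wrapped(*args, **kwargs)
```

After a few pipeline inference runs, wrapper.captured holds the ordered input tuples that a PTQ API can consume as a calibration dataset.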

Comment on lines +141 to +145
dataset = datasets.load_dataset(
"google-research-datasets/conceptual_captions",
split="train",
trust_remote_code=True,
).shuffle(seed=42)
Contributor

Maybe expose the dataset name as an example parameter?

Contributor Author

Done
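One way to expose the dataset name as a parameter, sketched with argparse. The flag name and help text are illustrative assumptions, not the PR's exact CLI; only the default dataset id is taken from the quoted code above.

```python
import argparse

def build_parser():
    # Hypothetical CLI wiring: --calibration-dataset replaces the
    # hard-coded dataset name in the calibration code.
    parser = argparse.ArgumentParser(description="LCM export example")
    parser.add_argument(
        "--calibration-dataset",
        default="google-research-datasets/conceptual_captions",
        help="Hugging Face dataset providing calibration prompts.",
    )
    return parser
```

The calibration code would then call datasets.load_dataset(args.calibration_dataset, split="train", trust_remote_code=True).shuffle(seed=42) instead of the literal string.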

wrapped_unet = UNetWrapper(pipeline.unet, pipeline.unet.config)
pipeline.unet = wrapped_unet
# Run inference for data collection
pbar = tqdm(total=calibration_dataset_size)
Contributor

Maybe executorch already has some sort of progress bar? The fewer dependencies the better.

Contributor Author

tqdm is used in multiple places inside executorch examples too.

Comment on lines +179 to +187
if self.should_quantize_model(sd_model_component):
# Quantize activations for the Unet Model. Other models are weights-only quantized.
pipeline = self.model_loader.pipeline
try:
# We need the models in FP32 to run inference for calibration data collection
self._set_pipeline_dtype(pipeline, torch.float32)
calibration_dataset = self.get_unet_calibration_dataset(pipeline)
finally:
self._set_pipeline_dtype(pipeline, self.model_loader.dtype)
Contributor

If the calibration dataset is only collected under this condition for stable diffusion, I don't see the value of the should_quantize_model method.

@anzr299 anzr299 requested a review from daniil-lyakhov March 4, 2026 12:56
Comment on lines +180 to +194
# Quantize activations for the Unet Model. Other models are weights-only quantized.
pipeline = self.model_loader.pipeline
try:
# We need the models in FP32 to run inference for calibration data collection
self._set_pipeline_dtype(pipeline, torch.float32)
calibration_dataset = self.get_unet_calibration_dataset(pipeline)
finally:
self._set_pipeline_dtype(pipeline, self.model_loader.dtype)

quantized_model = quantize_model(
model,
mode=QuantizationMode.INT8_TRANSFORMER,
calibration_dataset=calibration_dataset,
smooth_quant=True,
)
Contributor

What about separate functions quantize and compress?

Comment on lines +314 to +319
exported_program = self._export_and_maybe_quantize(
vae_decoder,
dummy_inputs[sd_model_component],
sd_model_component,
self.is_quantization_enabled,
)
Contributor

export_program = self._export(...)
if quantize:
    exported_model = quantize(exported_model)

Contributor Author

Ah okay, you mean separating the export and quantize logic?

Contributor Author

Regarding #17807 (comment)
The quantization is already one separate function. Do you mean a function inside this file that both collects the calibration dataset and quantizes?

Regarding compression, sure, I will move it to a separate function. I thought I would show the changes first (removing the try-finally removes the bulk of the code there).

Contributor

A PTQ function for the UNet and a WC function for the other parts; I don't see a reason why WC and PTQ should be united under one single function.

Contributor Author

Oh, I see. I made it into a single function that performs quantization or compression depending on the model.

This seemed more reasonable because, for the user, both are quantization (compression == weights-only quantization); "compression" is more of an NNCF term.
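The single-entry-point design described in this thread can be sketched as below. All names (quantize_model_component and the two private helpers) and the tag strings are hypothetical; the real code would call the backend's PTQ and weights-only quantization APIs.

```python
def quantize_model_component(component, model, calibration_dataset=None):
    # One user-facing "quantize" entry point: activation PTQ for the
    # UNet, weights-only quantization for every other component.
    if component == "unet":
        return _ptq_int8(model, calibration_dataset)
    return _weights_only_int8(model)

def _ptq_int8(model, calibration_dataset):
    # Placeholder for activation + weight INT8 PTQ using the
    # calibration dataset.
    return f"{model}:int8-ptq"

def _weights_only_int8(model):
    # Placeholder for weights-only INT8 compression (no calibration).
    return f"{model}:int8-weights-only"
```

The trade-off versus two public functions is exactly the one debated above: a single entry point matches the user's mental model ("everything is quantization"), at the cost of a dispatch branch inside the function.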

@anzr299 anzr299 requested a review from daniil-lyakhov March 10, 2026 16:41
exported_program_module
)
# Re-export the transformed torch.fx.GraphModule to ExportedProgram
exported_program = export(exported_program_module, component_dummy_inputs)
Contributor

You can put it into the quantization function.

Contributor Author

Done

)

@staticmethod
def _compress_non_unet_model(
Contributor

Suggested change
def _compress_non_unet_model(
def compress_model(

Contributor Author

Done

pipeline.unet = original_unet
return calibration_data

def _quantize_unet_model(
Contributor

Suggested change
def _quantize_unet_model(
def quantize_unet_model(

Contributor Author

Done

Comment on lines +138 to +140
prompt = batch[dataset_column]
if not isinstance(prompt, str):
prompt = str(prompt)
Contributor

Suggested change
prompt = batch[dataset_column]
if not isinstance(prompt, str):
prompt = str(prompt)
prompt = str(batch[dataset_column])

Should work

Contributor Author

Alright

prompt = batch[dataset_column]
if not isinstance(prompt, str):
prompt = str(prompt)
if len(prompt.split()) > pipeline.tokenizer.model_max_length:
Contributor

Are you sure the number of tokens and the number of whitespace-separated words are always equal?

Contributor Author

Done
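The reviewer's point is that subword tokenizers generally emit more tokens than whitespace-separated words, so len(prompt.split()) can pass prompts that exceed tokenizer.model_max_length. A toy illustration, where toy_tokenize is a crude hypothetical stand-in for pipeline.tokenizer (it breaks words into 4-character pieces, roughly like BPE might):

```python
def toy_tokenize(text):
    # Crude stand-in for a subword tokenizer: split on whitespace, then
    # break each word into 4-character pieces.
    pieces = []
    for word in text.split():
        pieces.extend(word[i:i + 4] for i in range(0, len(word), 4))
    return pieces

def prompt_fits(prompt, model_max_length, tokenize=toy_tokenize):
    # Compare the actual token count, not the word count, to the limit.
    return len(tokenize(prompt)) <= model_max_length
```

With a real pipeline the check would use len(pipeline.tokenizer(prompt).input_ids) rather than counting spaces.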

@anzr299 anzr299 marked this pull request as ready for review March 14, 2026 10:25
@anzr299 anzr299 requested a review from lucylq as a code owner March 14, 2026 10:25
Copilot AI review requested due to automatic review settings March 14, 2026 10:25
Contributor

Copilot AI left a comment

Pull request overview

Extends the OpenVINO Stable Diffusion (LCM) example to support INT8 post-training quantization (PTQ) during export, and exposes the new dtype option in both export and inference CLIs.

Changes:

  • Add --dtype int8 path in export_lcm.py with UNet calibration + quantization and weights-only compression for other components.
  • Introduce StableDiffusionComponent enum to avoid stringly-typed component keys for dummy inputs.
  • Update example docs and dependencies to reflect quantization usage.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
examples/openvino/stable_diffusion/requirements.txt Adds tqdm dependency for calibration progress reporting.
examples/openvino/stable_diffusion/openvino_lcm.py Extends dtype CLI choices to include int8.
examples/openvino/stable_diffusion/export_lcm.py Implements PTQ export flow, calibration dataset handling, and component-key refactor.
examples/openvino/stable_diffusion/README.md Documents INT8 export and inference usage.
examples/models/stable_diffusion/model.py Adds StableDiffusionComponent enum and switches dummy input dict keys to the enum.
examples/models/stable_diffusion/__init__.py Re-exports the new StableDiffusionComponent.


Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings March 14, 2026 10:37
Contributor

Copilot AI left a comment

Pull request overview

This PR extends the OpenVINO Stable Diffusion (LCM) example to support PTQ quantization (int8) during export and adds CLI/README updates to expose the new workflow.

Changes:

  • Add --dtype int8 option for export and inference scripts.
  • Implement PTQ quantization for UNet (activation-aware) and weights-only compression for other components during export.
  • Introduce StableDiffusionComponent enum to make component naming/lookup more robust.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
examples/openvino/stable_diffusion/requirements.txt Adds an extra dependency for the example environment.
examples/openvino/stable_diffusion/openvino_lcm.py Allows --dtype int8 in the inference CLI.
examples/openvino/stable_diffusion/export_lcm.py Implements quantization + calibration dataset collection and wires it into export.
examples/openvino/stable_diffusion/README.md Documents the new int8 export/inference path.
examples/models/stable_diffusion/model.py Adds StableDiffusionComponent enum and uses it for dummy-input keys.
examples/models/stable_diffusion/__init__.py Re-exports the new enum.

