-
Notifications
You must be signed in to change notification settings - Fork 872
[OpenVINO][Examples] Add Quantization for the OpenVINO Stable Diffusion Example #17807
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
056ed58
810214d
a7a41e3
135fd60
a164209
1b605aa
ab09c86
a291ed8
b27ef24
5463abe
d457d76
0ad4a25
e065b1b
08111fe
d28c5cb
6b7eada
5d98f9e
9d7b0d7
d6b8933
d6db584
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,15 +10,25 @@ | |
| import logging | ||
| import os | ||
|
|
||
| import datasets # type: ignore[import-untyped] | ||
| import nncf # type: ignore[import-untyped] | ||
|
|
||
| import torch | ||
|
|
||
| from executorch.backends.openvino.partitioner import OpenvinoPartitioner | ||
| from executorch.backends.openvino.quantizer import ( | ||
| OpenVINOQuantizer, | ||
| QuantizationMode, | ||
| quantize_model, | ||
| ) | ||
| from executorch.examples.models.stable_diffusion.model import ( # type: ignore[import-untyped] | ||
| LCMModelLoader, | ||
| StableDiffusionComponent, | ||
| ) | ||
| from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower | ||
| from executorch.exir.backend.backend_details import CompileSpec | ||
| from torch.export import export | ||
| from tqdm import tqdm # type: ignore[import-untyped] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is the type ignore here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It was suggested by the lintrunner. |
||
|
|
||
| # Configure logging | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
@@ -31,27 +41,180 @@ class LCMOpenVINOExporter: | |
| def __init__( | ||
| self, | ||
| model_id: str = "SimianLuo/LCM_Dreamshaper_v7", | ||
| is_quantization_enabled: bool = False, | ||
| dtype: torch.dtype = torch.float16, | ||
| calibration_dataset_name: str = "google-research-datasets/conceptual_captions", | ||
| calibration_dataset_column: str = "caption", | ||
| ): | ||
| if is_quantization_enabled: | ||
| dtype = torch.float32 | ||
| self.is_quantization_enabled = is_quantization_enabled | ||
| self.calibration_dataset_name = calibration_dataset_name | ||
| self.calibration_dataset_column = calibration_dataset_column | ||
| self.model_loader = LCMModelLoader(model_id=model_id, dtype=dtype) | ||
|
|
||
| def load_models(self) -> bool: | ||
| """Load the LCM pipeline and extract components""" | ||
| return self.model_loader.load_models() | ||
|
|
||
| @staticmethod | ||
| def get_unet_calibration_dataset( | ||
| pipeline, | ||
| dataset_name: str, | ||
| dataset_column: str, | ||
| calibration_dataset_size: int = 200, | ||
| num_inference_steps: int = 4, | ||
| ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: | ||
| """Collect UNet calibration inputs from prompts.""" | ||
|
|
||
| class UNetWrapper(torch.nn.Module): | ||
| def __init__(self, model: torch.nn.Module, config): | ||
| super().__init__() | ||
| self.model = model | ||
| self.config = config | ||
| self.captured_args: list[ | ||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | ||
| ] = [] | ||
|
|
||
| def _pick_correct_arg_or_kwarg( | ||
| self, | ||
| name: str, | ||
| args, | ||
| kwargs, | ||
| idx: int, | ||
| ): | ||
| if name in kwargs and kwargs[name] is not None: | ||
| return kwargs[name] | ||
| if len(args) > idx: | ||
| return args[idx] | ||
| raise KeyError(f"Missing required UNet input: {name}") | ||
|
|
||
| def _process_inputs( | ||
| self, *args, **kwargs | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| sample = self._pick_correct_arg_or_kwarg("sample", args, kwargs, 0) | ||
| timestep = self._pick_correct_arg_or_kwarg("timestep", args, kwargs, 1) | ||
| encoder_hidden_states = self._pick_correct_arg_or_kwarg( | ||
| "encoder_hidden_states", args, kwargs, 2 | ||
| ) | ||
| timestep = ( | ||
| timestep.unsqueeze(0) | ||
| if isinstance(timestep, torch.Tensor) and timestep.dim() == 0 | ||
| else timestep | ||
| ) | ||
| processed_args = ( | ||
| sample, | ||
| timestep, | ||
| encoder_hidden_states, | ||
| ) | ||
| return processed_args | ||
|
|
||
| def forward(self, *args, **kwargs): | ||
| """ | ||
| Obtain and pass each input individually to ensure the order is maintained | ||
| and the right values are being passed according to the expected inputs by | ||
| the OpenVINO LCM runner. | ||
| """ | ||
| unet_args = self._process_inputs(*args, **kwargs) | ||
| self.captured_args.append(unet_args) | ||
| return self.model(*args, **kwargs) | ||
|
|
||
| calibration_data = [] | ||
| dataset = datasets.load_dataset( | ||
| dataset_name, | ||
| split="train", | ||
anzr299 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| streaming=True, | ||
| ).shuffle(seed=42) | ||
| original_unet = pipeline.unet | ||
| wrapped_unet = UNetWrapper(pipeline.unet, pipeline.unet.config) | ||
| pipeline.unet = wrapped_unet | ||
| # Run inference for data collection | ||
| pbar = tqdm(total=calibration_dataset_size) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe executorch already has some sort of progress bar? The fewer dependencies, the better.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. tqdm is already used in multiple places inside the executorch examples. |
||
| try: | ||
| for batch in dataset: | ||
| if dataset_column not in batch: | ||
| raise RuntimeError( | ||
| f"Column '{dataset_column}' was not found in dataset '{dataset_name}'" | ||
| ) | ||
| prompt = batch[dataset_column] | ||
| tokenized_prompt = pipeline.tokenizer.encode(prompt) | ||
| if len(tokenized_prompt) > pipeline.tokenizer.model_max_length: | ||
| continue | ||
| # Run the pipeline | ||
| pipeline( | ||
| prompt, | ||
| num_inference_steps=num_inference_steps, | ||
| height=512, | ||
| width=512, | ||
| output_type="latent", | ||
| ) | ||
| calibration_data.extend(wrapped_unet.captured_args) | ||
| wrapped_unet.captured_args = [] | ||
| pbar.update(len(calibration_data) - pbar.n) | ||
| if pbar.n >= calibration_dataset_size: | ||
| break | ||
anzr299 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| finally: | ||
| pipeline.unet = original_unet | ||
| pbar.close() | ||
| return calibration_data | ||
|
|
||
| def quantize_unet_model( | ||
| self, | ||
| model: torch.export.ExportedProgram, | ||
| dummy_inputs, | ||
| ) -> torch.export.ExportedProgram: | ||
| """Quantize UNet using activation-aware PTQ.""" | ||
| pipeline = self.model_loader.pipeline | ||
| calibration_dataset = self.get_unet_calibration_dataset( | ||
| pipeline, | ||
| self.calibration_dataset_name, | ||
| self.calibration_dataset_column, | ||
| ) | ||
| model = model.module() | ||
| quantized_model = quantize_model( | ||
| model, | ||
| mode=QuantizationMode.INT8_TRANSFORMER, | ||
| calibration_dataset=calibration_dataset, | ||
| smooth_quant=True, | ||
| ) | ||
| # Re-export the transformed torch.fx.GraphModule to ExportedProgram | ||
| quantized_exported_program = export(quantized_model, dummy_inputs) | ||
| return quantized_exported_program | ||
|
|
||
| @staticmethod | ||
| def compress_model( | ||
| model: torch.export.ExportedProgram, | ||
| dummy_inputs, | ||
| ) -> torch.export.ExportedProgram: | ||
| """Apply weights-only compression for non-UNet components.""" | ||
| model = model.module() | ||
| ov_quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8WO_ASYM) | ||
| quantized_model = nncf.experimental.torch.fx.compress_pt2e( | ||
| model, quantizer=ov_quantizer | ||
| ) | ||
| # Re-export the transformed torch.fx.GraphModule to ExportedProgram | ||
| quantized_exported_program = export(quantized_model, dummy_inputs) | ||
| return quantized_exported_program | ||
|
|
||
| def export_text_encoder(self, output_path: str, device: str = "CPU") -> bool: | ||
| """Export CLIP text encoder to PTE file""" | ||
| try: | ||
| logger.info("Exporting text encoder with OpenVINO backend...") | ||
|
|
||
| sd_model_component = StableDiffusionComponent.TEXT_ENCODER | ||
|
|
||
| # Get wrapped model and dummy inputs | ||
| text_encoder_wrapper = self.model_loader.get_text_encoder_wrapper() | ||
| dummy_inputs = self.model_loader.get_dummy_inputs() | ||
|
|
||
| # Export to ATEN graph | ||
| exported_program = export( | ||
| text_encoder_wrapper, dummy_inputs["text_encoder"] | ||
| ) | ||
| component_dummy_inputs = dummy_inputs[sd_model_component] | ||
| exported_program = export(text_encoder_wrapper, component_dummy_inputs) | ||
|
|
||
| if self.is_quantization_enabled: | ||
| exported_program = self.compress_model( | ||
| exported_program, component_dummy_inputs | ||
| ) | ||
|
|
||
| # Configure OpenVINO compilation | ||
| compile_spec = [CompileSpec("device", device.encode())] | ||
|
|
@@ -85,13 +248,20 @@ def export_unet(self, output_path: str, device: str = "CPU") -> bool: | |
| """Export UNet model to PTE file""" | ||
| try: | ||
| logger.info("Exporting UNet model with OpenVINO backend...") | ||
| sd_model_component = StableDiffusionComponent.UNET | ||
|
|
||
| # Get wrapped model and dummy inputs | ||
| unet_wrapper = self.model_loader.get_unet_wrapper() | ||
| dummy_inputs = self.model_loader.get_dummy_inputs() | ||
|
|
||
| # Export to ATEN graph | ||
| exported_program = export(unet_wrapper, dummy_inputs["unet"]) | ||
| component_dummy_inputs = dummy_inputs[sd_model_component] | ||
| exported_program = export(unet_wrapper, component_dummy_inputs) | ||
|
|
||
| if self.is_quantization_enabled: | ||
| exported_program = self.quantize_unet_model( | ||
| exported_program, component_dummy_inputs | ||
| ) | ||
|
|
||
| # Configure OpenVINO compilation | ||
| compile_spec = [CompileSpec("device", device.encode())] | ||
|
|
@@ -125,13 +295,20 @@ def export_vae_decoder(self, output_path: str, device: str = "CPU") -> bool: | |
| """Export VAE decoder to PTE file""" | ||
| try: | ||
| logger.info("Exporting VAE decoder with OpenVINO backend...") | ||
| sd_model_component = StableDiffusionComponent.VAE_DECODER | ||
|
|
||
| # Get wrapped model and dummy inputs | ||
| vae_decoder = self.model_loader.get_vae_decoder() | ||
| dummy_inputs = self.model_loader.get_dummy_inputs() | ||
|
|
||
| # Export to ATEN graph | ||
| exported_program = export(vae_decoder, dummy_inputs["vae_decoder"]) | ||
| component_dummy_inputs = dummy_inputs[sd_model_component] | ||
| exported_program = export(vae_decoder, component_dummy_inputs) | ||
|
|
||
| if self.is_quantization_enabled: | ||
| exported_program = self.compress_model( | ||
| exported_program, component_dummy_inputs | ||
| ) | ||
|
|
||
| # Configure OpenVINO compilation | ||
| compile_spec = [CompileSpec("device", device.encode())] | ||
|
|
@@ -223,9 +400,23 @@ def create_argument_parser(): | |
|
|
||
| parser.add_argument( | ||
| "--dtype", | ||
| choices=["fp16", "fp32"], | ||
| choices=["fp16", "fp32", "int8"], | ||
| default="fp16", | ||
| help="Model data type (default: fp16)", | ||
| help="Model data type. Use int8 to enable PTQ quantization (default: fp16)", | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--calibration_dataset_name", | ||
| type=str, | ||
| default="google-research-datasets/conceptual_captions", | ||
| help="HuggingFace dataset used for UNet calibration when INT8 quantization is enabled", | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--calibration_dataset_column", | ||
| type=str, | ||
| default="caption", | ||
| help="Dataset column name used as prompt text for UNet calibration", | ||
| ) | ||
|
|
||
| parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") | ||
|
|
@@ -249,11 +440,18 @@ def main() -> int: | |
| logger.info("=" * 60) | ||
|
|
||
| # Map dtype string to torch dtype | ||
| dtype_map = {"fp16": torch.float16, "fp32": torch.float32} | ||
| is_quantization_enabled = args.dtype == "int8" | ||
| dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "int8": torch.float32} | ||
| dtype = dtype_map[args.dtype] | ||
|
|
||
| # Create exporter and load models | ||
| exporter = LCMOpenVINOExporter(args.model_id, dtype=dtype) | ||
| exporter = LCMOpenVINOExporter( | ||
| args.model_id, | ||
| is_quantization_enabled=is_quantization_enabled, | ||
| dtype=dtype, | ||
| calibration_dataset_name=args.calibration_dataset_name, | ||
| calibration_dataset_column=args.calibration_dataset_column, | ||
| ) | ||
|
|
||
| if not exporter.load_models(): | ||
| logger.error("Failed to load models") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.