Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions examples/models/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@
# except in compliance with the License. See the license file found in the
# LICENSE file in the root directory of this source tree.

from .model import LCMModelLoader, TextEncoderWrapper, UNetWrapper, VAEDecoder
from .model import (
LCMModelLoader,
StableDiffusionComponent,
TextEncoderWrapper,
UNetWrapper,
VAEDecoder,
)

__all__ = ["LCMModelLoader", "TextEncoderWrapper", "UNetWrapper", "VAEDecoder"]
__all__ = [
"LCMModelLoader",
"StableDiffusionComponent",
"TextEncoderWrapper",
"UNetWrapper",
"VAEDecoder",
]
17 changes: 13 additions & 4 deletions examples/models/stable_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import logging
from enum import Enum
from typing import Any, Optional

import torch
Expand All @@ -26,6 +27,14 @@
logger = logging.getLogger(__name__)


class StableDiffusionComponent(Enum):
    """Enumerates the exportable Stable Diffusion pipeline components.

    Using enum members rather than bare strings keeps the component
    names consistent between the model loader and the exporters.
    """

    # CLIP text encoder producing prompt embeddings.
    TEXT_ENCODER = "text_encoder"
    # Denoising UNet.
    UNET = "unet"
    # VAE decoder that turns latents into images.
    VAE_DECODER = "vae_decoder"


class TextEncoderWrapper(torch.nn.Module):
"""Wrapper for CLIP text encoder that extracts last_hidden_state"""

Expand Down Expand Up @@ -150,7 +159,7 @@ def get_vae_decoder(self) -> VAEDecoder:
raise ValueError("Models not loaded. Call load_models() first.")
return VAEDecoder(self.vae)

def get_dummy_inputs(self):
def get_dummy_inputs(self) -> dict[StableDiffusionComponent, tuple[Any, ...]]:
"""
Get dummy inputs for each model component.

Expand Down Expand Up @@ -187,7 +196,7 @@ def get_dummy_inputs(self):
vae_input = torch.randn(1, 4, 64, 64, dtype=self.dtype)

return {
"text_encoder": (text_encoder_input,),
"unet": unet_inputs,
"vae_decoder": (vae_input,),
StableDiffusionComponent.TEXT_ENCODER: (text_encoder_input,),
StableDiffusionComponent.UNET: unet_inputs,
StableDiffusionComponent.VAE_DECODER: (vae_input,),
}
11 changes: 11 additions & 0 deletions examples/openvino/stable_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ python export_lcm.py \
--device CPU \
--dtype fp16
```

To quantize the UNet with 8-bit activations and 8-bit weights (8a8w) and apply weights-only 8-bit quantization (16a8w) to the remaining components, run:
```bash
python export_lcm.py \
--model_id SimianLuo/LCM_Dreamshaper_v7 \
--output_dir ./lcm_models \
--device CPU \
--dtype int8
```

This will create three files in `./lcm_models/`:
- `text_encoder.pte`
- `unet.pte`
Expand All @@ -33,6 +43,7 @@ This will create three files in `./lcm_models/`:
### Generate Images

Run inference with the exported model:
Note: For models exported with INT8 quantization, pass `--dtype int8` when running inference.

```bash
python openvino_lcm.py \
Expand Down
216 changes: 207 additions & 9 deletions examples/openvino/stable_diffusion/export_lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,25 @@
import logging
import os

import datasets # type: ignore[import-untyped]
import nncf # type: ignore[import-untyped]

import torch

from executorch.backends.openvino.partitioner import OpenvinoPartitioner
from executorch.backends.openvino.quantizer import (
OpenVINOQuantizer,
QuantizationMode,
quantize_model,
)
from executorch.examples.models.stable_diffusion.model import ( # type: ignore[import-untyped]
LCMModelLoader,
StableDiffusionComponent,
)
from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower
from executorch.exir.backend.backend_details import CompileSpec
from torch.export import export
from tqdm import tqdm # type: ignore[import-untyped]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the type ignore here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was suggested by the lintrunner


# Configure logging
logging.basicConfig(level=logging.INFO)
Expand All @@ -31,27 +41,180 @@ class LCMOpenVINOExporter:
def __init__(
    self,
    model_id: str = "SimianLuo/LCM_Dreamshaper_v7",
    is_quantization_enabled: bool = False,
    dtype: torch.dtype = torch.float16,
    calibration_dataset_name: str = "google-research-datasets/conceptual_captions",
    calibration_dataset_column: str = "caption",
):
    """Configure the exporter and create the underlying model loader.

    Args:
        model_id: HuggingFace model id of the LCM pipeline to export.
        is_quantization_enabled: When True, models are loaded in fp32
            (the requested ``dtype`` is ignored) and quantization is
            applied during export.
        dtype: Precision used to load the models; forced to fp32 when
            quantization is enabled.
        calibration_dataset_name: HuggingFace dataset providing UNet
            calibration prompts.
        calibration_dataset_column: Dataset column holding prompt text.
    """
    # Quantization works on full-precision weights, so any requested
    # half precision is overridden before the loader is built.
    load_dtype = torch.float32 if is_quantization_enabled else dtype
    self.is_quantization_enabled = is_quantization_enabled
    self.calibration_dataset_name = calibration_dataset_name
    self.calibration_dataset_column = calibration_dataset_column
    self.model_loader = LCMModelLoader(model_id=model_id, dtype=load_dtype)

def load_models(self) -> bool:
    """Delegate model loading to the underlying LCMModelLoader.

    Returns:
        True when the pipeline and its components were loaded successfully.
    """
    loader = self.model_loader
    return loader.load_models()

@staticmethod
def get_unet_calibration_dataset(
    pipeline,
    dataset_name: str,
    dataset_column: str,
    calibration_dataset_size: int = 200,
    num_inference_steps: int = 4,
) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Collect UNet calibration inputs by running the pipeline on prompts.

    The pipeline's UNet is temporarily replaced with a transparent proxy
    that records every ``(sample, timestep, encoder_hidden_states)`` call;
    the original UNet is always restored afterwards.

    Args:
        pipeline: Loaded diffusers pipeline whose UNet is being calibrated.
        dataset_name: HuggingFace dataset name providing prompt text
            (loaded in streaming mode and shuffled with a fixed seed).
        dataset_column: Column of the dataset containing the prompts.
        calibration_dataset_size: Number of input tuples to collect.
        num_inference_steps: Denoising steps per prompt; each step yields
            one captured input tuple.

    Returns:
        At most ``calibration_dataset_size`` captured UNet input tuples.

    Raises:
        RuntimeError: If ``dataset_column`` is not present in the dataset.
    """

    class _CapturingUNet(torch.nn.Module):
        """UNet proxy that records its inputs and delegates the call."""

        def __init__(self, model: torch.nn.Module, config):
            super().__init__()
            self.model = model
            # Expose the real UNet's config so the pipeline keeps working.
            self.config = config
            self.captured_args: list[
                tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            ] = []

        def _pick_correct_arg_or_kwarg(self, name: str, args, kwargs, idx: int):
            # Prefer an explicitly passed keyword; fall back to position.
            if name in kwargs and kwargs[name] is not None:
                return kwargs[name]
            if len(args) > idx:
                return args[idx]
            raise KeyError(f"Missing required UNet input: {name}")

        def _process_inputs(
            self, *args, **kwargs
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            sample = self._pick_correct_arg_or_kwarg("sample", args, kwargs, 0)
            timestep = self._pick_correct_arg_or_kwarg("timestep", args, kwargs, 1)
            encoder_hidden_states = self._pick_correct_arg_or_kwarg(
                "encoder_hidden_states", args, kwargs, 2
            )
            # Export expects a batched timestep, so lift 0-d tensors.
            if isinstance(timestep, torch.Tensor) and timestep.dim() == 0:
                timestep = timestep.unsqueeze(0)
            return (sample, timestep, encoder_hidden_states)

        def forward(self, *args, **kwargs):
            # Capture inputs in the positional order expected by the
            # OpenVINO LCM runner, then delegate to the real UNet.
            self.captured_args.append(self._process_inputs(*args, **kwargs))
            return self.model(*args, **kwargs)

    calibration_data: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
    dataset = datasets.load_dataset(
        dataset_name,
        split="train",
        streaming=True,
    ).shuffle(seed=42)
    original_unet = pipeline.unet
    wrapped_unet = _CapturingUNet(pipeline.unet, pipeline.unet.config)
    pipeline.unet = wrapped_unet
    # Run inference purely to collect inputs; latent output is discarded.
    pbar = tqdm(total=calibration_dataset_size)
    try:
        for batch in dataset:
            if dataset_column not in batch:
                raise RuntimeError(
                    f"Column '{dataset_column}' was not found in dataset '{dataset_name}'"
                )
            prompt = batch[dataset_column]
            # Skip prompts the tokenizer would truncate.
            tokenized_prompt = pipeline.tokenizer.encode(prompt)
            if len(tokenized_prompt) > pipeline.tokenizer.model_max_length:
                continue
            # Run the pipeline
            pipeline(
                prompt,
                num_inference_steps=num_inference_steps,
                height=512,
                width=512,
                output_type="latent",
            )
            calibration_data.extend(wrapped_unet.captured_args)
            wrapped_unet.captured_args = []
            pbar.update(len(calibration_data) - pbar.n)
            if pbar.n >= calibration_dataset_size:
                break
    finally:
        # Always restore the real UNet, even if collection fails mid-run.
        pipeline.unet = original_unet
        pbar.close()
    # Each pipeline call contributes num_inference_steps tuples, so the
    # final batch can overshoot the target; trim to the requested size.
    return calibration_data[:calibration_dataset_size]

def quantize_unet_model(
    self,
    model: torch.export.ExportedProgram,
    dummy_inputs,
) -> torch.export.ExportedProgram:
    """Apply activation-aware INT8 PTQ to an exported UNet.

    Calibration inputs are gathered by running the loaded pipeline on
    prompts from the configured dataset; the quantized graph is then
    re-exported so downstream lowering receives an ExportedProgram.
    """
    calibration_dataset = self.get_unet_calibration_dataset(
        self.model_loader.pipeline,
        self.calibration_dataset_name,
        self.calibration_dataset_column,
    )
    fx_module = model.module()
    quantized_fx = quantize_model(
        fx_module,
        mode=QuantizationMode.INT8_TRANSFORMER,
        calibration_dataset=calibration_dataset,
        smooth_quant=True,
    )
    # Re-export the transformed torch.fx.GraphModule to ExportedProgram
    return export(quantized_fx, dummy_inputs)

@staticmethod
def compress_model(
    model: torch.export.ExportedProgram,
    dummy_inputs,
) -> torch.export.ExportedProgram:
    """Apply 8-bit weights-only compression for non-UNet components.

    The exported program is unwrapped to a torch.fx.GraphModule,
    compressed through NNCF's pt2e flow with an asymmetric INT8
    weights-only OpenVINO quantizer, and then re-exported.
    """
    fx_module = model.module()
    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8WO_ASYM)
    compressed_fx = nncf.experimental.torch.fx.compress_pt2e(
        fx_module, quantizer=quantizer
    )
    # Re-export the transformed torch.fx.GraphModule to ExportedProgram
    return export(compressed_fx, dummy_inputs)

def export_text_encoder(self, output_path: str, device: str = "CPU") -> bool:
"""Export CLIP text encoder to PTE file"""
try:
logger.info("Exporting text encoder with OpenVINO backend...")

sd_model_component = StableDiffusionComponent.TEXT_ENCODER

# Get wrapped model and dummy inputs
text_encoder_wrapper = self.model_loader.get_text_encoder_wrapper()
dummy_inputs = self.model_loader.get_dummy_inputs()

# Export to ATEN graph
exported_program = export(
text_encoder_wrapper, dummy_inputs["text_encoder"]
)
component_dummy_inputs = dummy_inputs[sd_model_component]
exported_program = export(text_encoder_wrapper, component_dummy_inputs)

if self.is_quantization_enabled:
exported_program = self.compress_model(
exported_program, component_dummy_inputs
)

# Configure OpenVINO compilation
compile_spec = [CompileSpec("device", device.encode())]
Expand Down Expand Up @@ -85,13 +248,20 @@ def export_unet(self, output_path: str, device: str = "CPU") -> bool:
"""Export UNet model to PTE file"""
try:
logger.info("Exporting UNet model with OpenVINO backend...")
sd_model_component = StableDiffusionComponent.UNET

# Get wrapped model and dummy inputs
unet_wrapper = self.model_loader.get_unet_wrapper()
dummy_inputs = self.model_loader.get_dummy_inputs()

# Export to ATEN graph
exported_program = export(unet_wrapper, dummy_inputs["unet"])
component_dummy_inputs = dummy_inputs[sd_model_component]
exported_program = export(unet_wrapper, component_dummy_inputs)

if self.is_quantization_enabled:
exported_program = self.quantize_unet_model(
exported_program, component_dummy_inputs
)

# Configure OpenVINO compilation
compile_spec = [CompileSpec("device", device.encode())]
Expand Down Expand Up @@ -125,13 +295,20 @@ def export_vae_decoder(self, output_path: str, device: str = "CPU") -> bool:
"""Export VAE decoder to PTE file"""
try:
logger.info("Exporting VAE decoder with OpenVINO backend...")
sd_model_component = StableDiffusionComponent.VAE_DECODER

# Get wrapped model and dummy inputs
vae_decoder = self.model_loader.get_vae_decoder()
dummy_inputs = self.model_loader.get_dummy_inputs()

# Export to ATEN graph
exported_program = export(vae_decoder, dummy_inputs["vae_decoder"])
component_dummy_inputs = dummy_inputs[sd_model_component]
exported_program = export(vae_decoder, component_dummy_inputs)

if self.is_quantization_enabled:
exported_program = self.compress_model(
exported_program, component_dummy_inputs
)

# Configure OpenVINO compilation
compile_spec = [CompileSpec("device", device.encode())]
Expand Down Expand Up @@ -223,9 +400,23 @@ def create_argument_parser():

parser.add_argument(
"--dtype",
choices=["fp16", "fp32"],
choices=["fp16", "fp32", "int8"],
default="fp16",
help="Model data type (default: fp16)",
help="Model data type. Use int8 to enable PTQ quantization (default: fp16)",
)

parser.add_argument(
"--calibration_dataset_name",
type=str,
default="google-research-datasets/conceptual_captions",
help="HuggingFace dataset used for UNet calibration when INT8 quantization is enabled",
)

parser.add_argument(
"--calibration_dataset_column",
type=str,
default="caption",
help="Dataset column name used as prompt text for UNet calibration",
)

parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
Expand All @@ -249,11 +440,18 @@ def main() -> int:
logger.info("=" * 60)

# Map dtype string to torch dtype
dtype_map = {"fp16": torch.float16, "fp32": torch.float32}
is_quantization_enabled = args.dtype == "int8"
dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "int8": torch.float32}
dtype = dtype_map[args.dtype]

# Create exporter and load models
exporter = LCMOpenVINOExporter(args.model_id, dtype=dtype)
exporter = LCMOpenVINOExporter(
args.model_id,
is_quantization_enabled=is_quantization_enabled,
dtype=dtype,
calibration_dataset_name=args.calibration_dataset_name,
calibration_dataset_column=args.calibration_dataset_column,
)

if not exporter.load_models():
logger.error("Failed to load models")
Expand Down
2 changes: 1 addition & 1 deletion examples/openvino/stable_diffusion/openvino_lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def create_argument_parser():
"--device", choices=["CPU", "GPU"], default="CPU", help="Target device"
)
parser.add_argument(
"--dtype", choices=["fp16", "fp32"], default="fp16", help="Model dtype"
"--dtype", choices=["fp16", "fp32", "int8"], default="fp16", help="Model dtype"
)
parser.add_argument(
"--output_dir", type=str, default="./lcm_outputs", help="Output directory"
Expand Down
Loading
Loading