Skip to content

Main#18182

Open
mitrobik wants to merge 5 commits intopytorch:mainfrom
mitrobik:main
Open

Main#18182
mitrobik wants to merge 5 commits intopytorch:mainfrom
mitrobik:main

Conversation

@mitrobik
Copy link

Summary

[PLEASE REMOVE] See CONTRIBUTING.md's Pull Requests for ExecuTorch PR guidelines.

[PLEASE REMOVE] If this PR closes an issue, please add a Fixes #<issue-id> line.

[PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: " label. For a list of available release notes labels, check out CONTRIBUTING.md's Pull Requests.

Test plan

[PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18182

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 9 Awaiting Approval

As of commit add9c87 with merge base 8bec69b (image):

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla
Copy link

meta-cla bot commented Mar 14, 2026

Hi @mitrobik!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@github-actions
Copy link

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copy link
Author

@mitrobik mitrobik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright (c) Meta Platforms, Inc. and affiliates.

Copyright 2024 Arm Limited and/or its affiliates.

All rights reserved.

This source code is licensed under the BSD-style license found in the

LICENSE file in the root directory of this source tree.

Part of this code is from pybind11 cmake_example:

https://github.com/pybind/cmake_example/blob/master/setup.py so attach the

license below.

Copyright (c) 2016 The Pybind Development Team, All rights reserved.

Redistribution and use in source and binary forms, with or without

modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this

list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,

this list of conditions and the following disclaimer in the documentation

and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors

may be used to endorse or promote products derived from this software

without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND

ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED

WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE

DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE

FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL

DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR

SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER

CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,

OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

You are under no obligation whatsoever to provide any bug fixes, patches, or

upgrades to the features, functionality or performance of the source code

("Enhancements") to anyone; however, if you choose to make your Enhancements

available either publicly, or directly to the author of this software, without

imposing a separate written license agreement for such Enhancements, then you

hereby grant the following license: a non-exclusive, royalty-free perpetual

license to install, use, modify, prepare derivative works, incorporate into

other computer software, distribute, and sublicense such enhancements or

derivative works thereof, in binary and source code form.

import contextlib

Import this before distutils so that setuptools can intercept the distuils

imports.

import importlib.util
import logging
import os
import re
import shutil
import site
import subprocess
import sys
from distutils import log # type: ignore[import-not-found]
from distutils.sysconfig import get_python_lib # type: ignore[import-not-found]
from pathlib import Path
from typing import List, Optional

Clean dynamic import using importlib

_install_utils_path = Path(file).parent / "install_utils.py"
_spec = importlib.util.spec_from_file_location("install_utils", _install_utils_path)
if _spec is None:
raise ImportError(f"Could not create module spec for {_install_utils_path}")
install_utils = importlib.util.module_from_spec(_spec)
if _spec.loader is None:
raise ImportError(f"Module spec has no loader for {_install_utils_path}")
_spec.loader.exec_module(install_utils)

    from setuptools import Extension, setup
    from setuptools.command.build import build
    from setuptools.command.build_ext import build_ext
    from setuptools.command.build_py import build_py
    
    logging.basicConfig(
        level=logging.INFO,
            format="%(asctime)s [%(levelname)s] %(message)s",
            )
            
            try:
                from tools.cmake.cmake_cache import CMakeCache
                except ImportError:
                    sys.path.insert(
                            0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "tools", "cmake")
                                )
                                    from cmake_cache import CMakeCache  # type: ignore[no-redef, import-not-found]
                                    
                                    
                                    def _is_macos() -> bool:
                                        return sys.platform == "darwin"
                                        
                                        
                                        def _is_windows() -> bool:
                                            return sys.platform == "win32"
                                            
                                            
                                            class Version:
                                                """Static strings that describe the version of the pip package."""
                                                
                                                    # Cached values returned by the properties.
                                                        __root_dir_attr: Optional[str] = None
                                                            __string_attr: Optional[str] = None
                                                                __git_hash_attr: Optional[str] = None
                                                                
                                                                    @classmethod
                                                                        def _root_dir(cls) -> str:
                                                                                """The path to the root of the git repo."""
                                                                                        if cls.__root_dir_attr is None:
                                                                                                    # This setup.py file lives in the root of the repo.
                                                                                                                cls.__root_dir_attr = str(Path(__file__).parent.resolve())
                                                                                                                        return str(cls.__root_dir_attr)
                                                                                                                        
                                                                                                                            @classmethod
                                                                                                                                def git_hash(cls) -> Optional[str]:
                                                                                                                                        """The current git hash, if known."""
                                                                                                                                                if cls.__git_hash_attr is None:
                                                                                                                                                            import subprocess
                                                                                                                                                            
                                                                                                                                                                        try:
                                                                                                                                                                                        cls.__git_hash_attr = (
                                                                                                                                                                                                            subprocess.check_output(
                                                                                                                                                                                                                                    ["git", "rev-parse", "HEAD"], cwd=cls._root_dir()
                                                                                                                                                                                                                                                        )
                                                                                                                                                                                                                                                                            .decode("ascii")
                                                                                                                                                                                                                                                                                                .strip()
                                                                                                                                                                                                                                                                                                                )
                                                                                                                                                                                                                                                                                                                            except subprocess.CalledProcessError:
                                                                                                                                                                                                                                                                                                                                            cls.__git_hash_attr = ""  # Non-None but empty.
                                                                                                                                                                                                                                                                                                                                                    # A non-None but empty value indicates that we don't know it.
                                                                                                                                                                                                                                                                                                                                                            return cls.__git_hash_attr if cls.__git_hash_attr else None
                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                @classmethod
                                                                                                                                                                                                                                                                                                                                                                    def string(cls) -> str:
                                                                                                                                                                                                                                                                                                                                                                            """The version string."""
                                                                                                                                                                                                                                                                                                                                                                                    if cls.__string_attr is None:
                                                                                                                                                                                                                                                                                                                                                                                                # If set, BUILD_VERSION should override any local version
                                                                                                                                                                                                                                                                                                                                                                                                            # information. CI will use this to manage, e.g., release vs. nightly
                                                                                                                                                                                                                                                                                                                                                                                                                        # versions.
                                                                                                                                                                                                                                                                                                                                                                                                                                    version = os.getenv("BUILD_VERSION", "").strip()
                                                                                                                                                                                                                                                                                                                                                                                                                                                if not version:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                # Otherwise, read the version from a local file and add the git
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # commit if available.
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                version = (
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    open(os.path.join(cls._root_dir(), "version.txt")).read().strip()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    )
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    if cls.git_hash():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        version += "+" + cls.git_hash()[:7]  # type: ignore[index]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    cls.__string_attr = version

@mitrobik mitrobik marked this pull request as ready for review March 14, 2026 08:10
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import argparse
import sys
import tempfile
from pathlib import Path
from typing import List, Tuple

import coremltools as ct
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.exir import EdgeProgramManager

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.tracer import Value
from tabulate import tabulate


def get_root_dir_path() -> Path:
    return Path(__file__).resolve().parent.parent.parent.parent.parent


    sys.path.append(str((get_root_dir_path() / "examples").resolve()))

    from inspector_utils import (
        build_sdk_runner_including_coreml,
            ComparisonResult,
                create_inspector_coreml,
                    create_inspector_reference,
                        get_comparison_result,
                            module_to_edge,
                            )

                            from models import MODEL_NAME_TO_MODEL
                            from models.model_factory import EagerModelFactory


                            def args_parser() -> argparse.ArgumentParser:
                                parser = argparse.ArgumentParser()

                                    parser.add_argument(
                                            "-m",
                                                    "--model_name",
                                                            required=True,
                                                                    help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
                                                                        )

                                                                            parser.add_argument(
                                                                                    "-c",
                                                                                            "--compute_unit",
                                                                                                    required=False,
                                                                                                            default=ct.ComputeUnit.ALL.name.lower(),
                                                                                                                    help=f"Provide compute unit for the model. Valid ones: {[[compute_unit.name.lower() for compute_unit in ct.ComputeUnit]]}",
                                                                                                                        )

                                                                                                                            parser.add_argument(
                                                                                                                                    "-precision",
                                                                                                                                            "--compute_precision",
                                                                                                                                                    required=False,
                                                                                                                                                            default=ct.precision.FLOAT16.value,
                                                                                                                                                                    help=f"Provide compute precision for the model. Valid ones: {[[precision.value for precision in ct.precision]]}",
                                                                                                                                                                        )

                                                                                                                                                                            parser.add_argument(
                                                                                                                                                                                    "--compile",
                                                                                                                                                                                            action=argparse.BooleanOptionalAction,
                                                                                                                                                                                                    required=False,
                                                                                                                                                                                                            default=False,
                                                                                                                                                                                                                )

                                                                                                                                                                                                                    parser.add_argument(
                                                                                                                                                                                                                            "-env",
                                                                                                                                                                                                                                    "--conda_environment_name",
                                                                                                                                                                                                                                            required=False,
                                                                                                                                                                                                                                                    default="executorch",
                                                                                                                                                                                                                                                            help="Provide conda environment name.",
                                                                                                                                                                                                                                                                )

                                                                                                                                                                                                                                                                    return parser


                                                                                                                                                                                                                                                                    def get_compile_specs_from_args(args):
                                                                                                                                                                                                                                                                        model_type = CoreMLBackend.MODEL_TYPE.MODEL
                                                                                                                                                                                                                                                                            if args.compile:
                                                                                                                                                                                                                                                                                    model_type = CoreMLBackend.MODEL_TYPE.COMPILED_MODEL

                                                                                                                                                                                                                                                                                        compute_precision = ct.precision(args.compute_precision)
                                                                                                                                                                                                                                                                                            compute_unit = ct.ComputeUnit[args.compute_unit.upper()]

                                                                                                                                                                                                                                                                                                return CoreMLBackend.generate_compile_specs(
                                                                                                                                                                                                                                                                                                        compute_precision=compute_precision,
                                                                                                                                                                                                                                                                                                                compute_unit=compute_unit,
                                                                                                                                                                                                                                                                                                                        model_type=model_type,
                                                                                                                                                                                                                                                                                                                                minimum_deployment_target=ct.target.iOS17,
                                                                                                                                                                                                                                                                                                                                    )


                                                                                                                                                                                                                                                                                                                                    def compare_intermediate_tensors(
                                                                                                                                                                                                                                                                                                                                        edge_program: EdgeProgramManager,
                                                                                                                                                                                                                                                                                                                                            example_inputs: Tuple[Value, ...],
                                                                                                                                                                                                                                                                                                                                                coreml_compile_specs: List[CompileSpec],
                                                                                                                                                                                                                                                                                                                                                    model_name: str,
                                                                                                                                                                                                                                                                                                                                                        working_dir_path: Path,
                                                                                                                                                                                                                                                                                                                                                        ) -> ComparisonResult:
                                                                                                                                                                                                                                                                                                                                                            inspector_coreml = create_inspector_coreml(
                                                                                                                                                                                                                                                                                                                                                                    edge_program=edge_program,
                                                                                                                                                                                                                                                                                                                                                                            compile_specs=coreml_compile_specs,
                                                                                                                                                                                                                                                                                                                                                                                    example_inputs=example_inputs,
                                                                                                                                                                                                                                                                                                                                                                                            model_name=model_name,
                                                                                                                                                                                                                                                                                                                                                                                                    working_dir_path=working_dir_path,
                                                                                                                                                                                                                                                                                                                                                                                                            root_dir_path=get_root_dir_path(),
                                                                                                                                                                                                                                                                                                                                                                                                                )

                                                                                                                                                                                                                                                                                                                                                                                                                    inspector_reference = create_inspector_reference(
                                                                                                                                                                                                                                                                                                                                                                                                                            edge_program=edge_program,
                                                                                                                                                                                                                                                                                                                                                                                                                                    example_inputs=example_inputs,
                                                                                                                                                                                                                                                                                                                                                                                                                                            model_name=model_name,
                                                                                                                                                                                                                                                                                                                                                                                                                                                    working_dir_path=working_dir_path,
                                                                                                                                                                                                                                                                                                                                                                                                                                                            root_dir_path=get_root_dir_path(),
                                                                                                                                                                                                                                                                                                                                                                                                                                                                )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                    return get_comparison_result(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                            inspector1=inspector_reference,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    tag1="reference",
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            inspector2=inspector_coreml,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    tag2="coreml",
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        )


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        def main() -> None:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            parser = args_parser()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                args = parser.parse_args()

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    if args.model_name not in MODEL_NAME_TO_MODEL:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            raise RuntimeError(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        f"Model {args.model_name} is not a valid name. "
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                valid_compute_units = [compute_unit.name.lower() for compute_unit in ct.ComputeUnit]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    if args.compute_unit not in valid_compute_units:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            raise RuntimeError(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        f"{args.compute_unit} is invalid. "
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    f"Valid compute units are {valid_compute_units}."
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                build_sdk_runner_including_coreml(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        root_dir_path=get_root_dir_path(), conda_env_name=args.conda_environment_name
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                model, example_inputs, _ = EagerModelFactory.create_model(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        *MODEL_NAME_TO_MODEL[args.model_name]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                model.eval()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    edge_program = module_to_edge(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            module=model,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    example_inputs=example_inputs,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            coreml_compile_specs = get_compile_specs_from_args(args)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                with tempfile.TemporaryDirectory() as temp_dir_name:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        working_dir_path = Path(temp_dir_name) / "debugger"
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                working_dir_path.mkdir(parents=True, exist_ok=True)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        comparison_result = compare_intermediate_tensors(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    edge_program=edge_program,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                example_inputs=example_inputs,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            coreml_compile_specs=coreml_compile_specs,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        model_name=args.model_name,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    working_dir_path=working_dir_path,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            )

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    print(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                tabulate(comparison_result.to_dataframe(), headers="keys", tablefmt="grid")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        )


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if __name__ == "__main__":
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            main()  # pragma: no cover
@meta-cla
Copy link

meta-cla bot commented Mar 14, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 14, 2026
@mitrobik mitrobik requested a review from mergennachin as a code owner March 14, 2026 14:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant