Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
#
# ==============================================================================

.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
Expand All @@ -118,6 +118,7 @@ help:
@echo " llava-cpu - Build Llava runner with CPU backend"
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
@echo " clean - Clean build artifacts"

voxtral-cuda:
Expand Down Expand Up @@ -332,6 +333,15 @@ gemma3-cpu:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner"

qwen3_5_moe-cuda:
@echo "==> Building and installing ExecuTorch with CUDA..."
cmake --workflow --preset llm-release-cuda
@echo "==> Building Qwen3.5 MoE runner with CUDA..."
cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-cuda
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"

clean:
rm -rf cmake-out \
extension/llm/tokenizers/build \
Expand Down
78 changes: 78 additions & 0 deletions examples/models/qwen3_5_moe/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.24)
project(qwen3_5_moe)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# gflags
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
find_package(gflags REQUIRED)

# executorch
list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)

# CPU ops
list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Needed for cpuinfo where it uses android specific log lib
if(ANDROID)
list(APPEND link_libraries log)
endif()

# Extensions
list(APPEND link_libraries
extension_llm_runner
extension_module
extension_data_loader
extension_tensor
extension_flat_tensor
)

# CUDA backend (required)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda_backend)
if(NOT MSVC)
executorch_target_link_options_shared_lib(aoti_cuda_backend)
endif()

# Tokenizer
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(qwen3_5_moe_runner main.cpp)
target_include_directories(qwen3_5_moe_runner PUBLIC ${_common_include_directories})
target_link_libraries(qwen3_5_moe_runner PUBLIC ${link_libraries})

if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
target_link_options_gc_sections(qwen3_5_moe_runner)
if(NOT APPLE AND NOT MSVC)
target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s")
endif()
endif()

# On Windows, copy required DLLs to the executable directory
if(MSVC AND EXECUTORCH_BUILD_CUDA)
add_custom_command(
TARGET qwen3_5_moe_runner
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:aoti_cuda_shims>
$<TARGET_FILE_DIR:qwen3_5_moe_runner>
COMMENT "Copying aoti_cuda_shims.dll to qwen3_5_moe_runner directory"
)
endif()
52 changes: 52 additions & 0 deletions examples/models/qwen3_5_moe/CMakePresets.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
"version": 6,
"configurePresets": [
{
"name": "qwen3-5-moe-base",
"hidden": true,
"binaryDir": "${sourceDir}/../../../cmake-out/examples/models/qwen3_5_moe",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release",
"CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out",
"CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out"
}
},
{
"name": "qwen3-5-moe-cuda",
"displayName": "Qwen3.5 MoE runner (CUDA)",
"inherits": ["qwen3-5-moe-base"],
"cacheVariables": {
"EXECUTORCH_BUILD_CUDA": "ON"
},
"condition": {
"type": "inList",
"string": "${hostSystemName}",
"list": ["Linux", "Windows"]
}
}
],
"buildPresets": [
{
"name": "qwen3-5-moe-cuda",
"displayName": "Build Qwen3.5 MoE runner (CUDA)",
"configurePreset": "qwen3-5-moe-cuda",
"targets": ["qwen3_5_moe_runner"]
}
],
"workflowPresets": [
{
"name": "qwen3-5-moe-cuda",
"displayName": "Configure and build Qwen3.5 MoE runner (CUDA)",
"steps": [
{
"type": "configure",
"name": "qwen3-5-moe-cuda"
},
{
"type": "build",
"name": "qwen3-5-moe-cuda"
}
]
}
]
}
160 changes: 160 additions & 0 deletions examples/models/qwen3_5_moe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Qwen 3.5 MoE

Self-contained ExecuTorch implementation of
[Qwen3.5-MoE-A3B](https://huggingface.co/Qwen/Qwen3.5-MoE-A3B),
a ~35B total / ~3B active parameter Mixture-of-Experts language model.
Weights are loaded directly from the HuggingFace safetensors checkpoint.
CUDA backend only.

## Overview

The pipeline has two stages: **export** (Python, once) and **inference**
(C++ runner, repeated). Export converts the HuggingFace checkpoint into a
`model.pte` file with int4 quantization. At inference time, the C++ runner
loads the `.pte`, `.ptd`, and a HuggingFace tokenizer, then generates text.

## Architecture

Qwen 3.5 MoE is a hybrid-attention transformer with sparse Mixture of Experts:

```
Input tokens
|
v
Token Embedding (no learned position embedding — RoPE is inside attention)
|
v
+--- Decoder Layer x40 -----------------------------------------+
| |
| GemmaRMSNorm -> Attention (hybrid) -> residual add |
| +- 75% of layers: GatedDeltaNet (linear, O(n)) |
| +- 25% of layers: Full Attention (softmax, O(n^2)) |
| |
| GemmaRMSNorm -> Sparse MoE -> residual add |
| +- Router: softmax -> top-8 expert selection |
| +- 256 routed experts: independent SwiGLU MLPs |
| +- Shared expert: always-on SwiGLU with sigmoid gate |
| |
+----------------------------------------------------------------+
|
v
GemmaRMSNorm -> LM Head -> logits
```

### Key parameters

| Parameter | Value |
|-----------|-------|
| `hidden_size` | 2048 |
| `num_hidden_layers` | 40 |
| `num_attention_heads` / `num_kv_heads` | 16 / 2 |
| `head_dim` | 256 |
| `partial_rotary_factor` | 0.25 (64 of 256 dims rotated) |
| `linear_num_key_heads` / `linear_num_value_heads` | 16 / 32 |
| `linear_key_head_dim` / `linear_value_head_dim` | 128 / 128 |
| `num_experts` / `num_experts_per_tok` | 256 / 8 |
| `moe_intermediate_size` | 512 |
| `vocab_size` | 248320 |

### Key components

| Component | Description |
|-----------|-------------|
| **GemmaRMSNorm** | `x / sqrt(mean(x^2) + eps) * (1 + weight)` — unit-offset variant |
| **Full Attention** | GQA with output gate (sigmoid), QK-norm (GemmaRMSNorm), partial RoPE (25% of dims) |
| **GatedDeltaNet** | Linear attention via recurrent state. Mamba-style gating: `g = -exp(A_log) * softplus(a + dt_bias)`. Causal conv1d, L2-normalized Q/K, delta rule recurrence. Uses FLA Triton kernel. |
| **Sparse MoE** | Router selects top-8 of 256 experts per token. Shared expert with sigmoid gate always runs. |

## Prerequisites

- ExecuTorch installed from source (see [building from source](../../../docs/source/using-executorch-building-from-source.md))
- [safetensors](https://pypi.org/project/safetensors/) (`pip install safetensors`)
- NVIDIA GPU with CUDA toolkit
- Model weights downloaded from HuggingFace. The directory should contain
`config.json`, `model.safetensors.index.json`, safetensors shards, and
`tokenizer.json`.

## Export

Export produces a `model.pte` and `aoti_cuda_blob.ptd` containing the
compiled CUDA kernels and quantized weights. Int4 quantization is
recommended — the model is too large to fit in VRAM at bf16.

```bash
python export.py \
--model-dir ~/models/Qwen3.5-MoE-A3B \
--output-dir ./qwen35_moe_exports \
--qlinear 4w \
--qlinear-packing-format tile_packed_to_4d
```

### Options

| Flag | Default | Description |
|------|---------|-------------|
| `--model-dir` | (required) | HuggingFace model directory with `config.json` + safetensors |
| `--output-dir` | `./qwen35_moe_exports` | Output directory |
| `--max-seq-len` | `4096` | KV cache length |
| `--qlinear` | (none) | Linear layer quantization: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear-group-size` | `32` | Group size for linear quantization |
| `--qlinear-packing-format` | (none) | Packing format for 4w: `tile_packed_to_4d` |
| `--qembedding` | (none) | Embedding quantization: `8w` |

## Build

ExecuTorch must be installed from source first (see
[Prerequisites](#prerequisites)). The `make` target handles building
core libraries and the runner binary.

```bash
make qwen3_5_moe-cuda
```

This builds ExecuTorch with CUDA backend support, then the runner binary
at `cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner`.

## Run

The runner requires:
- `model.pte` — exported model (see [Export](#export))
- `aoti_cuda_blob.ptd` — CUDA delegate data file (produced during export)
- `tokenizer.json` — HuggingFace tokenizer from the model weights directory

```bash
cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \
--model_path qwen35_moe_exports/model.pte \
--data_path qwen35_moe_exports/aoti_cuda_blob.ptd \
--tokenizer_path ~/models/Qwen3.5-MoE-A3B/tokenizer.json \
--prompt "The meaning of life is" \
--max_new_tokens 128
```

| Flag | Default | Description |
|------|---------|-------------|
| `--model_path` | (required) | Path to exported `.pte` model |
| `--data_path` | (none) | Path to `.ptd` delegate data file (required for CUDA) |
| `--tokenizer_path` | (required) | Path to HuggingFace `tokenizer.json` |
| `--prompt` | `"Hello"` | Input prompt text |
| `--temperature` | `0.8` | Sampling temperature (0 = greedy) |
| `--max_new_tokens` | `128` | Maximum tokens to generate |

## Files

| File | Description |
|------|-------------|
| `model.py` | Export-friendly model definition with all components |
| `export.py` | Export + quantize + lower to CUDA `.pte` |
| `main.cpp` | C++ runner using ExecuTorch's TextLLMRunner |
| `CMakeLists.txt` | Build configuration |
| `CMakePresets.json` | CMake presets for CUDA build |

## Troubleshooting

- **OOM during export**: The model requires significant GPU memory even
with int4 quantization. Try reducing `--max-seq-len` or using a GPU
with more VRAM.
- **OOM during loading**: The 35B parameter model requires ~70 GB RAM to
load from safetensors before quantization. Ensure sufficient system
memory.
- **Missing `aoti_cuda_blob.ptd`**: This file is produced during export
alongside the `.pte`. Both files are required for inference.
Empty file.
Loading
Loading