Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,13 @@ def _build_parser():

parser.add_argument("-v", "--verbose", action="store_true")

parser.add_argument(
"--calibration_num_threads",
type=int,
default=0,
help="Thread count for calibration forward passes. 0 = auto-tune (default).",
)

return parser


Expand Down
76 changes: 69 additions & 7 deletions examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import inspect
import json
import logging
import os
import time
import types

from functools import partial
Expand Down Expand Up @@ -412,6 +414,53 @@ def _tag_ios(self, node, fixed_point_type):

return quant_io_type

def _auto_tune_calibration_threads(self):
    """Pick a thread count for calibration via a quick microbenchmark.

    AR1 decode calibration is SGEMV-dominated (memory-bandwidth-bound).
    The default thread count (os.cpu_count()) is typically far too high,
    causing massive OpenMP sync overhead. This runs a few forward passes
    of ``self.decoder(*self.export_input)`` at several candidate thread
    counts and picks the fastest.

    The thread count that was active on entry is always benchmarked too,
    so the logged comparison reflects a real measurement, and it is
    restored before returning even if a forward pass raises.

    Returns:
        int: the candidate thread count with the lowest measured time.
        On ties the smallest candidate wins.
    """
    timed_runs = 3

    # os.cpu_count() may return None when the count is undeterminable.
    logical_cpus = os.cpu_count() or 1
    # Clamp to 1 so a single-CPU host never yields a 0-thread candidate
    # (torch.set_num_threads requires a positive value).
    physical_cores = max(1, logical_cpus // 2)
    original = torch.get_num_threads()

    # Sweep from deep in the linear-scaling region through the
    # bandwidth-saturation knee up to the current default. DRAM
    # bandwidth typically saturates at 8-16 cores on x86; threads
    # beyond that add only OpenMP sync overhead for SGEMV workloads.
    candidates = sorted(
        {
            max(1, physical_cores // 8),
            max(1, physical_cores // 4),
            max(2, physical_cores // 2),
            max(2, physical_cores * 3 // 4),
            physical_cores,
            physical_cores + physical_cores // 2,
            logical_cpus,
            # Include the current setting so the baseline is measured.
            original,
        }
    )

    timings = {}
    try:
        for num_threads in candidates:
            torch.set_num_threads(num_threads)
            with torch.no_grad():
                self.decoder(*self.export_input)  # warmup
                start = time.perf_counter()
                for _ in range(timed_runs):
                    self.decoder(*self.export_input)
                timings[num_threads] = time.perf_counter() - start
    finally:
        # Never leak a benchmark thread count into the caller.
        torch.set_num_threads(original)

    # min() returns the first minimum in (sorted) insertion order, so the
    # smallest thread count wins on ties.
    best_threads = min(timings, key=timings.get)
    logging.info(
        "Auto-tune calibration threads: tested %s, best=%d (%.1fms/fwd vs %.1fms at %d)",
        candidates,
        best_threads,
        timings[best_threads] / timed_runs * 1000,
        timings[original] / timed_runs * 1000,
        original,
    )
    return best_threads

def _calibrate(
self,
model,
Expand Down Expand Up @@ -560,14 +609,27 @@ def quantize(self, request: Request): # noqa: C901

# start calibration (only for kv mode or prefill mode without kv cache)
if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
self._calibrate(
model=self.decoder,
tokenizer=data.tokenizer,
event="prepare_pt2e",
user_calibration_data=data.calibration_data.datasets,
tok_embedding=self.tok_embedding,
intermediate_outputs=image_embedding,
calib_threads = getattr(self.control_args, "calibration_num_threads", 0)
if calib_threads <= 0:
calib_threads = self._auto_tune_calibration_threads()
original_threads = torch.get_num_threads()
torch.set_num_threads(calib_threads)
logging.info(
"Calibration using %d threads (was %d)",
calib_threads,
original_threads,
)
try:
self._calibrate(
model=self.decoder,
tokenizer=data.tokenizer,
event="prepare_pt2e",
user_calibration_data=data.calibration_data.datasets,
tok_embedding=self.tok_embedding,
intermediate_outputs=image_embedding,
)
finally:
torch.set_num_threads(original_threads)
else:
# one dummy inference to remove affine observer
# error happened in convert_pt2e
Expand Down
Loading