diff --git a/README.md b/README.md index 28f4efd1..5bb67558 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ xmake && xmake install - 运行模型推理测试 ```bash -python scripts/jiuge.py [--cpu | --nvidia | --qy | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] path/to/model_dir [n_device] +python scripts/jiuge.py [--cpu | --nvidia | --qy | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] path/to/model_dir [n_device] ``` - 部署模型推理服务 @@ -63,6 +63,12 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA ``` + - 选择是否使用kv caching,默认为false;在支持了此算子的平台(英伟达、阿里、天数、沐曦、海光、QY)可以使用 + ```bash + xmake f --use-kv-caching= [true | false] -cv + ``` + + - 安装 InfiniLM Python 包 ```bash pip install -e . @@ -71,11 +77,11 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA - 单次推理测试 - llama示例 ```bash - python examples/llama.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --ali] --model_path= + python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --ali | --cambricon | --hygon] --model_path= ``` - 例如: ```bash - python examples/llama.py --nvidia --model_path=/models/TinyLlama-1.1B-Chat-v1.0 + python examples/jiuge.py --nvidia --model_path=/models/TinyLlama-1.1B-Chat-v1.0 ``` - 分布式推理测试 - 9g示例 @@ -113,7 +119,7 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA - 运行推理基准测试(C-Eval/MMLU) ```bash - python test/bench/test_benchmark.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] --bench {ceval|mmlu} [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH] + python test/bench/test_benchmark.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] --bench {ceval|mmlu} [--backend cpp] [--ndev N] [--subject SUBJECT] [--num_samples N] 
[--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH] ``` - 参数说明: @@ -154,3 +160,21 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA python test/bench/test_benchmark.py --nvidia /models/9G7B_MHA --bench mmlu --subject abstract_algebra --backend cpp --ndev 1 --cache_dir ~/.cache/huggingface/datasets/ ``` > 注意:`--cache_dir` 应指向包含 `ceval___ceval-exam` 和 `cais___mmlu` 等数据集子目录的父目录,而不是直接指向这些子目录 + + - 试验中功能 + - Warm Up + ```bash + python examples/bench.py --nvidia --model= --warmup + ``` + - Paged Attention + ```bash + python examples/bench.py --nvidia --model= --enable-paged-attn + ``` + - CUDA Graph + ```bash + python examples/bench.py --nvidia --model= --enable-paged-attn --enable-graph + ``` + - 选择attention后端 (使用flash attention后端需要先在InfiniCore完成相关配置和编译) + ```bash + python examples/bench.py --nvidia --model= --enable-paged-attn [--attn=default | --attn=flash-attn] + ``` diff --git a/csrc/backends/attention_backends.hpp b/csrc/backends/attention_backends.hpp new file mode 100644 index 00000000..5cf66305 --- /dev/null +++ b/csrc/backends/attention_backends.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include <string> +#include <stdexcept> + +namespace infinilm::backends { + +enum class AttentionBackend { + Default, + FlashAttn, +}; + +inline AttentionBackend parse_attention_backend(const std::string &backend) { + if (backend == "default") { + return AttentionBackend::Default; + } + if (backend == "flash-attn") { + return AttentionBackend::FlashAttn; + } + + throw std::invalid_argument( + "Invalid attention_backend: " + backend + ". 
Valid options are: default, flash-attn"); +} + +} // namespace infinilm::backends diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 9c3f0bcc..5f0de647 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -93,26 +93,24 @@ StaticKVCache::update(size_t layer_idx, auto device = k_cache_layer->device(); - if (device.getType() == infinicore::Device::Type::NVIDIA - || device.getType() == infinicore::Device::Type::ILUVATAR - || device.getType() == infinicore::Device::Type::METAX) { - infinicore::op::kv_caching_( - k_cache_layer, - v_cache_layer, - k, - v, - past_sequence_lengths); - } else { - size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; - auto result_len = cache_pos + update_len; - ASSERT(result_len <= cache_len_); - - auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); - auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); - - k_cache_update->copy_from(k); - v_cache_update->copy_from(v); - } +#ifdef ENABLE_KV_CACHING + infinicore::op::kv_caching_( + k_cache_layer, + v_cache_layer, + k, + v, + past_sequence_lengths); +#else + size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; + auto result_len = cache_pos + update_len; + ASSERT(result_len <= cache_len_); + + auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); + auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); + + k_cache_update->copy_from(k); + v_cache_update->copy_from(v); +#endif return {k_cache_layer, v_cache_layer}; } @@ -215,9 +213,9 @@ PagedKVCache::get_contiguous_kv( const infinicore::Tensor cache_lens, const infinicore::Tensor input_offsets, size_t request_id) { - ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64); - ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64); - ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64); + ASSERT_EQ(block_tables->dtype(), 
infinicore::DataType::I32); + ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32); + ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32); auto nreq = block_tables->size(0); auto block_tables_cpu = block_tables->to(infinicore::Device::cpu()); @@ -229,9 +227,9 @@ PagedKVCache::get_contiguous_kv( auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx); auto req = request_id; - auto cache_lens_ptr = reinterpret_cast(cache_lens_cpu->data()); - auto input_offsets_ptr = reinterpret_cast(input_offsets_cpu->data()); - int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]); + auto cache_lens_ptr = reinterpret_cast(cache_lens_cpu->data()); + auto input_offsets_ptr = reinterpret_cast(input_offsets_cpu->data()); + int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]); auto full_k = infinicore::Tensor::empty( {num_rank_k_heads_, (size_t)total_len, k_dim_}, @@ -245,7 +243,7 @@ PagedKVCache::get_contiguous_kv( size_t r = total_len % block_size_; for (size_t b = 0; b < nblocks; b++) { - size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data())); + size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data())); full_k->narrow({{1, b * block_size_, block_size_}}) ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)); @@ -254,7 +252,7 @@ PagedKVCache::get_contiguous_kv( } if (r > 0) { - size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data())); + size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data())); full_k->narrow({{1, nblocks * block_size_, r}}) ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}})); diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index 54b84cc1..a594008a 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -86,7 +86,7 @@ class PagedKVCacheConfig final : public 
CacheConfig { public: PagedKVCacheConfig( size_t num_blocks, - size_t block_size = 16); + size_t block_size = 256); std::unique_ptr unique_copy() const override; size_t num_blocks() const; diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp index 74616c0d..8f47f749 100644 --- a/csrc/engine/compiler/paged_compiler.cpp +++ b/csrc/engine/compiler/paged_compiler.cpp @@ -34,26 +34,27 @@ void PagedCompiler::compile() { size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end()); compiled_map_decode_.clear(); block_tables_holder_ = infinicore::Tensor::empty( - {nblocks}, infinicore::DataType::I64, infinicore::context::getDevice()); + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); set_zeros(block_tables_holder_); for (size_t b : decode_batch_sizes_) { size_t block_per_req = nblocks / b; InfinilmModel::Input input; input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice()); input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); - input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); + input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice()); set_zeros(input.input_ids.value()); set_zeros(input.position_ids.value()); set_zeros(input.total_sequence_lengths.value()); - std::vector total_sequence_lengths_vec(b, 1); - infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); - input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice()); - set_zeros(input.input_offsets.value()); - std::vector input_offsets_vec(b + 1, 0); + std::vector total_sequence_lengths_vec(b, 1); + 
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false); + input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + std::vector input_offsets_vec(b + 1, 0); for (size_t i = 0; i <= b; i++) { input_offsets_vec[i] = i; } - infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false); + infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false); + input.cu_seqlens = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + infinicore::context::memcpyH2D(input.cu_seqlens.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int32_t), false); input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1}); input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); set_zeros(input.slot_mapping.value()); @@ -91,6 +92,7 @@ PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input & graph_input.position_ids.value()->copy_from(input.position_ids.value()); graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 4a2d5e86..7dd76eb8 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -23,9 +23,11 @@ InferEngine::InferEngine( const distributed::DistConfig &distributed_config, 
infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, - bool enable_graph_compiling) // Changed parameter + bool enable_graph_compiling, + backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), - legacy_model_config_(config) { + legacy_model_config_(config), + attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } @@ -39,7 +41,8 @@ InferEngine::InferEngine( communication_group_.get_rank_info(r), cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), - enable_graph_compiling)); + enable_graph_compiling, + attention_backend_)); } // Compile the model on all workers @@ -51,8 +54,9 @@ InferEngine::InferEngine( const distributed::DistConfig &distributed_config, infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, - bool enable_graph_compiling) // Changed parameter - : communication_group_(distributed_config, device_type) { + bool enable_graph_compiling, + backends::AttentionBackend attention_backend) // Changed parameter + : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } @@ -69,7 +73,8 @@ InferEngine::InferEngine( communication_group_.get_rank_info(r), cache_config_ != nullptr ? 
cache_config_.get() : nullptr, barrier_.get(), - enable_graph_compiling)); + enable_graph_compiling, + attention_backend_)); } // Compile the model on all workers this->compile(); @@ -117,6 +122,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { to_device(past_sequence_lengths), // @todo: on device in the future to_device(total_sequence_lengths), to_device(input_offsets), + to_device(cu_seqlens), to_device(block_tables), to_device(slot_mapping), }; @@ -169,7 +175,7 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) { for (auto &worker : workers_) { worker->wait(); } - + cache_config_ = new_config->unique_copy(); this->compile(); } diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 22e428ec..d191cfa1 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -37,14 +37,16 @@ class InferEngine { const distributed::DistConfig &distributed_config = distributed::DistConfig(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, - bool enable_graph_compiling = false); + bool enable_graph_compiling = false, + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); InferEngine( const std::string &model_path = "", const distributed::DistConfig &distributed_config = distributed::DistConfig(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, - bool enable_graph_compiling = false); + bool enable_graph_compiling = false, + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); @@ -73,6 +75,7 @@ class InferEngine { std::unique_ptr cache_config_; const InfinilmModel::Config &legacy_model_config_ = 
InfinilmModel::Config(); std::shared_ptr model_config_; + backends::AttentionBackend attention_backend_ = backends::AttentionBackend::Default; }; } // namespace infinilm::engine diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 3a2f53ec..c9cdc97f 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -26,9 +26,11 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, - bool enable_graph_compiling) + bool enable_graph_compiling, + backends::AttentionBackend attention_backend) : legacy_model_config_(model_config), rank_info_(rank_info), + attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), job_cmd_(Command::INIT), has_job_(false), @@ -53,9 +55,11 @@ RankWorker::RankWorker( const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, - bool enable_graph_compiling) + bool enable_graph_compiling, + backends::AttentionBackend attention_backend) : model_config_(model_config), rank_info_(rank_info), + attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), job_cmd_(Command::INIT), has_job_(false), @@ -234,10 +238,18 @@ void RankWorker::thread_loop() { // Create model using factory (may be expensive) if (model_config_ == nullptr) { - model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); + model_ = InfinilmModelFactory::createModel( + legacy_model_config_, + rank_info_, + pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, + attention_backend_); } else { - model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? 
pending_cache_config_.get() : nullptr); + model_ = InfinilmModelFactory::createModel( + model_config_, + rank_info_, + pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, + attention_backend_); } if (!model_) { @@ -339,7 +351,7 @@ void RankWorker::thread_loop() { const auto &batch_size{logits_shape[0]}; auto n_req = local_args.input_offsets.value()->size(0) - 1; - int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data(); + int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data(); auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)}; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index f738ec1f..719c1cb1 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../backends/attention_backends.hpp" #include "../cache/cache.hpp" #include "../config/model_config.hpp" #include "../models/model_factory.hpp" @@ -37,8 +38,10 @@ class RankWorker { std::optional past_sequence_lengths; /// ToTal Lengths for each request sequence, of shape `[num_requests]`. std::optional total_sequence_lengths; - /// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`. + /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. std::optional input_offsets; + /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`. + std::optional cu_seqlens; /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. std::optional block_tables; /// Slot ids for each token `[seq]`. Used for paged cache. 
@@ -61,13 +64,15 @@ class RankWorker { const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, - bool enable_graph_compiling); + bool enable_graph_compiling, + backends::AttentionBackend attention_backend); RankWorker(std::shared_ptr model_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, - bool enable_graph_compiling); + bool enable_graph_compiling, + backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. void load_param(const std::string &name, @@ -107,6 +112,9 @@ class RankWorker { std::shared_ptr model_; std::shared_ptr cache_; + // Backends + backends::AttentionBackend attention_backend_; + // Graph Compiling bool enable_graph_compiling_; std::unique_ptr compiler_; diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index be7ebd0d..550bf1aa 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -27,6 +27,8 @@ class InfinilmModel : public infinicore::nn::Module { std::optional total_sequence_lengths; /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. std::optional input_offsets; + /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`. + std::optional cu_seqlens; /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. std::optional block_tables; /// Slot ids for each token `[seq]`. Used for paged cache. 
diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index a6b5ab78..fe76479b 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -4,6 +4,7 @@ #include "infinicore/nn/linear.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" +#include "infinicore/ops/mha_varlen.hpp" #include "infinicore/ops/mul.hpp" #include @@ -31,7 +32,8 @@ namespace infinilm::models::llama { LlamaAttention::LlamaAttention(const LlamaConfig &config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : layer_idx_(layer_idx), hidden_size_(config.hidden_size), num_attention_heads_(config.num_attention_heads), @@ -41,7 +43,9 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, use_bias_(config.attention_bias), use_output_bias_(config.attention_output_bias), use_qk_norm_(config.qk_norm), - max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) { + max_position_embeddings_(config.max_position_embeddings), + rank_info_(rank_info), + attention_backend_(attention_backend) { const auto &dtype{config.dtype}; int tp_rank = rank_info.tp_rank; @@ -75,7 +79,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, LlamaAttention::LlamaAttention(std::shared_ptr model_config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : model_config_(model_config), layer_idx_(layer_idx), hidden_size_(model_config->get("hidden_size")), @@ -86,7 +91,8 @@ LlamaAttention::LlamaAttention(std::shared_ptr mo use_bias_(model_config->get_or("attention_bias", true)), use_output_bias_(model_config->get_or("attention_output_bias", false)), max_position_embeddings_(model_config->get("max_position_embeddings")), - 
rank_info_(rank_info) { + rank_info_(rank_info), + attention_backend_(attention_backend) { const auto &dtype{model_config_->get_dtype()}; int tp_rank = rank_info.tp_rank; @@ -203,7 +209,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ->contiguous() ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] } else { - size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; + size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] @@ -238,6 +244,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd std::shared_ptr paged_kv_cache, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const { ASSERT(block_tables.has_value()); @@ -298,17 +305,31 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); if (is_prefill) { - infinicore::op::paged_attention_prefill_( - attn_output, - q_reshaped, - k_total, - v_total, - block_tables.value(), - total_sequence_lengths.value(), - input_offsets.value(), - std::nullopt, - scaling_); - + if (attention_backend_ == backends::AttentionBackend::FlashAttn) { + infinicore::op::mha_varlen_( + attn_output, + q_reshaped, + k_total->permute({0, 2, 1, 3}), + v_total->permute({0, 2, 1, 3}), + input_offsets.value(), + cu_seqlens.value(), + block_tables.value(), + max_position_embeddings_, + max_position_embeddings_, + std::nullopt, + scaling_); + } else { + 
infinicore::op::paged_attention_prefill_( + attn_output, + q_reshaped, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + input_offsets.value(), + std::nullopt, + scaling_); + } } else { infinicore::op::paged_attention_( attn_output, @@ -322,7 +343,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd } // 7. Project output - attn_output = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_}); + attn_output + = attn_output->view({1, seq_len, num_attention_heads_ * head_dim_}); return o_proj_->forward(attn_output); } @@ -332,6 +354,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const { if (!rotary_emb_) { @@ -340,7 +363,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat infinicore::Tensor output; if (auto paged_kv_cache = std::dynamic_pointer_cast(kv_cache)) { - output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping); + output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping); } else { output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths); diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 0f8f9a90..4fe369d6 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../../backends/attention_backends.hpp" #include "../../cache/kv_cache.hpp" #include "../../config/model_config.hpp" #include "../../engine/distributed/distributed.hpp" @@ -52,12 +53,14 @@ class LlamaAttention : public 
infinicore::nn::Module { LlamaAttention(const LlamaConfig &config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); LlamaAttention(std::shared_ptr model_config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); /** * @brief Forward pass: compute attention @@ -73,6 +76,7 @@ class LlamaAttention : public infinicore::nn::Module { std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const; @@ -104,6 +108,7 @@ class LlamaAttention : public infinicore::nn::Module { std::shared_ptr kv_cache, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const; @@ -132,6 +137,8 @@ class LlamaAttention : public infinicore::nn::Module { size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) float scaling_; + + backends::AttentionBackend attention_backend_; }; } // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_decoder_layer.cpp b/csrc/models/llama/llama_decoder_layer.cpp index 208771d2..752b3016 100644 --- a/csrc/models/llama/llama_decoder_layer.cpp +++ b/csrc/models/llama/llama_decoder_layer.cpp @@ -19,7 +19,8 @@ namespace infinilm::models::llama { LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) : 
layer_idx_(layer_idx), rank_info_(rank_info) { + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : layer_idx_(layer_idx), rank_info_(rank_info) { const auto &dtype{config.dtype}; // Initialize layer normalization layers @@ -29,14 +30,15 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, dtype, device); // Initialize attention and MLP modules - INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_); + INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_, attention_backend); INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_); } LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr model_config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) { + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : model_config_(model_config), layer_idx_(layer_idx), rank_info_(rank_info) { const auto &dtype{model_config_->get_dtype()}; // Initialize layer normalization layers INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"), @@ -45,7 +47,7 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const { // 1. Attention layer normalization input_layernorm_->forward_inplace(hidden_states, residual); // 2. Self-attention - hidden_states = self_attn_->forward(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping); + hidden_states = self_attn_->forward( + hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping); // 3. 
Post-attention layer normalization post_attention_layernorm_->forward_inplace(hidden_states, residual); diff --git a/csrc/models/llama/llama_decoder_layer.hpp b/csrc/models/llama/llama_decoder_layer.hpp index a56aec03..9f2826b0 100644 --- a/csrc/models/llama/llama_decoder_layer.hpp +++ b/csrc/models/llama/llama_decoder_layer.hpp @@ -48,12 +48,14 @@ class LlamaDecoderLayer : public infinicore::nn::Module { LlamaDecoderLayer(const LlamaConfig &config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); LlamaDecoderLayer(std::shared_ptr model_config, const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); /** * @brief Forward pass: process one decoder layer @@ -73,6 +75,7 @@ class LlamaDecoderLayer : public infinicore::nn::Module { std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mappin) const; diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index 50a39b43..9596e668 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -17,13 +17,14 @@ namespace infinilm::models::llama { */ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) { + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) { // Initialize module's device_ member device_ = device; const auto 
&dtype{config.dtype}; // Initialize base model - INFINICORE_NN_MODULE_INIT(model, config, device, rank_info); + INFINICORE_NN_MODULE_INIT(model, config, device, rank_info, attention_backend); // Initialize language modeling head // Note: If tie_word_embeddings is true, we would share weights with embed_tokens @@ -34,14 +35,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr model_config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) { + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) { // Initialize module's device_ member device_ = device; const auto &dtype{model_config->get_dtype()}; // Initialize base model - INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info); + INFINICORE_NN_MODULE_INIT(model, model_config, device, rank_info, attention_backend); // Initialize language modeling head // Note: If tie_word_embeddings is true, we would share weights with embed_tokens // For now, we create a separate linear layer @@ -56,12 +58,13 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { auto past_sequence_lengths = input.past_sequence_lengths; auto total_sequence_length = input.total_sequence_lengths; auto input_offsets = input.input_offsets; + auto cu_seqlens = input.cu_seqlens; auto block_tables = input.block_tables; auto slot_mapping = input.slot_mapping; // 1. Forward through base model to get hidden states auto hidden_states = model_->forward( - input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping); + input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, cu_seqlens, block_tables, slot_mapping); // 2. 
Apply language modeling head to get logits auto logits = lm_head_->forward(hidden_states); diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp index a6e078e7..5cc79dfe 100644 --- a/csrc/models/llama/llama_for_causal_lm.hpp +++ b/csrc/models/llama/llama_for_causal_lm.hpp @@ -42,11 +42,13 @@ class LlamaForCausalLM : public InfinilmModel { */ LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); LlamaForCausalLM(std::shared_ptr model_config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); /** * @brief Forward pass: compute language modeling logits diff --git a/csrc/models/llama/llama_mlp.cpp b/csrc/models/llama/llama_mlp.cpp index a3ab7859..282e2eca 100644 --- a/csrc/models/llama/llama_mlp.cpp +++ b/csrc/models/llama/llama_mlp.cpp @@ -71,19 +71,34 @@ LlamaMLP::LlamaMLP(std::shared_ptr model_config, } infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const { - // 1. Project to gate and up - auto hidden_states_mutable = hidden_states; - auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); + infinicore::Device::Type dev_type = hidden_states->device().getType(); + if(dev_type == infinicore::Device::Type::MOORE){ + // 1. Project to a single combined gate_up tensor + auto hidden_states_mutable = hidden_states; + auto gate_up = gate_up_proj_->forward(hidden_states_mutable); - // 2. 
Apply SwiGLU: silu(gate) * up - // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up - // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up - auto intermediate = infinicore::op::swiglu(up, gate); + // 2. Apply the fused silu_and_mul operator + // applies SiLU to the first half, and multiplies it by the second half. + // Mathematically equivalent to: result = SiLU(gate_up[..., :d]) * gate_up[..., d:] + auto intermediate = infinicore::op::silu_and_mul(gate_up); - // 3. Project down - auto output = down_proj_->forward(intermediate); + // 3. Project down + auto output = down_proj_->forward(intermediate); + return output; + } else{ + // 1. Project to gate and up + auto hidden_states_mutable = hidden_states; + auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); - return output; + // 2. Apply SwiGLU: silu(gate) * up + // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up + // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up + auto intermediate = infinicore::op::swiglu(up, gate); + + // 3. 
Project down + auto output = down_proj_->forward(intermediate); + return output; + } } } // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index c1c5eefb..81e8fd04 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -20,7 +20,8 @@ namespace infinilm::models::llama { */ LlamaModel::LlamaModel(const LlamaConfig &config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : config_(config), rank_info_(rank_info) { const auto &dtype{config.dtype}; // Initialize token embeddings @@ -34,7 +35,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config, layers_.reserve(config.num_hidden_layers); for (size_t i = 0; i < config.num_hidden_layers; ++i) { layers_.push_back(this->register_module( - "layers." + std::to_string(i), config, device, i, rank_info)); + "layers." + std::to_string(i), config, device, i, rank_info, attention_backend)); } // Initialize final layer normalization @@ -56,7 +57,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(std::shared_ptr model_config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + backends::AttentionBackend attention_backend) : model_config_(model_config), rank_info_(rank_info) { const auto &dtype{model_config_->get_dtype()}; // Initialize token embeddings @@ -69,7 +71,7 @@ LlamaModel::LlamaModel(std::shared_ptr model_conf layers_.reserve(model_config_->get("num_hidden_layers")); for (size_t i = 0; i < model_config_->get("num_hidden_layers"); ++i) { layers_.push_back(this->register_module( - "layers." + std::to_string(i), model_config_, device, i, rank_info)); + "layers." 
+ std::to_string(i), model_config_, device, i, rank_info, attention_backend)); } // Initialize final layer normalization INFINICORE_NN_MODULE_INIT(norm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"), @@ -92,6 +94,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) const { // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] @@ -109,6 +112,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, past_sequence_lengths, total_sequence_lengths, input_offsets, + cu_seqlens, block_tables, slot_mapping); } diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index f293a97a..416e1a5c 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -51,11 +51,13 @@ class LlamaModel : public infinicore::nn::Module { */ LlamaModel(const LlamaConfig &config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); LlamaModel(std::shared_ptr model_config, const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); /** * @brief Forward pass: process input through the model @@ -73,6 +75,7 @@ class LlamaModel : public infinicore::nn::Module { std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping) 
const; diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 89ea715e..319a7baa 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -17,12 +17,13 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info, - const cache::CacheConfig *cache) { + const cache::CacheConfig *cache, + backends::AttentionBackend attention_backend) { std::shared_ptr model; if (const auto llama_config_ptr = dynamic_cast(&config)) { const auto &llama_config = *llama_config_ptr; model = std::make_shared( - llama_config, rank_info.device, rank_info); + llama_config, rank_info.device, rank_info, attention_backend); } else { throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); } @@ -37,12 +38,13 @@ std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info, - const cache::CacheConfig *cache) { + const cache::CacheConfig *cache, + backends::AttentionBackend attention_backend) { std::shared_ptr model; if (true) { model = std::make_shared( - model_config, rank_info.device, rank_info); + model_config, rank_info.device, rank_info, attention_backend); } else { throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); } diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index 02385029..3c3c2e38 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -3,6 +3,7 @@ #include "../config/model_config.hpp" #include "infinilm_model.hpp" +#include "../backends/attention_backends.hpp" #include "../engine/distributed/distributed.hpp" namespace infinilm { @@ -23,11 +24,13 @@ class InfinilmModelFactory { static std::shared_ptr createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = 
engine::distributed::RankInfo(), - const cache::CacheConfig *cache = nullptr); + const cache::CacheConfig *cache = nullptr, + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); static std::shared_ptr createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - const cache::CacheConfig *cache = nullptr); + const cache::CacheConfig *cache = nullptr, + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); }; } // namespace infinilm diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index d9f1985c..46f97ea4 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -37,7 +37,7 @@ inline void bind_cache(py::module &m) { .def( py::init(), py::arg("num_blocks"), - py::arg("block_size") = 16) + py::arg("block_size") = 256) .def( "num_blocks", &infinilm::cache::PagedKVCacheConfig::num_blocks) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 78af5daa..4aeec8af 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -36,19 +36,22 @@ inline void bind_infer_engine(py::module &m) { const distributed::DistConfig &dist, infinicore::Device::Type dev, std::shared_ptr cache_cfg, - bool enable_graph_compiling) { + bool enable_graph_compiling, + const std::string &attention_backend) { return std::make_shared( cfg, dist, dev, cache_cfg ? 
cache_cfg.get() : nullptr, - enable_graph_compiling); + enable_graph_compiling, + infinilm::backends::parse_attention_backend(attention_backend)); }), py::arg("config"), py::arg("distributed_config") = distributed::DistConfig(), py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), - py::arg("enable_graph_compiling") = false) + py::arg("enable_graph_compiling") = false, + py::arg("attention_backend") = "default") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -63,11 +66,14 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) - .def("get_cache_config", [](const InferEngine &self) { + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); - return std::shared_ptr(std::move(cfg->unique_copy())); }) + return cfg ? 
std::shared_ptr(cfg->unique_copy()) : nullptr; + }) .def("__repr__", [](const InferEngine &self) { return ""; }); infer_engine @@ -76,19 +82,22 @@ inline void bind_infer_engine(py::module &m) { const distributed::DistConfig &dist, infinicore::Device::Type dev, std::shared_ptr cache_cfg, - bool enable_graph_compiling) { + bool enable_graph_compiling, + const std::string &attention_backend) { return std::make_shared( model_path, dist, dev, cache_cfg ? cache_cfg.get() : nullptr, - enable_graph_compiling); + enable_graph_compiling, + infinilm::backends::parse_attention_backend(attention_backend)); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), - py::arg("enable_graph_compiling") = false) + py::arg("enable_graph_compiling") = false, + py::arg("attention_backend") = "default") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -103,8 +112,10 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) @@ -118,6 +129,7 @@ inline void bind_infer_engine(py::module &m) { std::optional past_sequence_lengths, std::optional total_sequence_lengths, std::optional input_offsets, + std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping, py::kwargs kwargs) { @@ -127,6 +139,7 @@ inline void bind_infer_engine(py::module &m) { std::move(past_sequence_lengths), std::move(total_sequence_lengths), std::move(input_offsets), + std::move(cu_seqlens), std::move(block_tables), std::move(slot_mapping), }; @@ -167,6 +180,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("past_sequence_lengths") = std::nullopt, py::arg("total_sequence_lengths") = std::nullopt, py::arg("input_offsets") = std::nullopt, + py::arg("cu_seqlens") = std::nullopt, py::arg("block_tables") = std::nullopt, py::arg("slot_mapping") = std::nullopt) .def_readwrite("input_ids", &InferEngine::Input::input_ids) @@ -174,6 +188,7 @@ inline void bind_infer_engine(py::module &m) { .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths) .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths) .def_readwrite("input_offsets", &InferEngine::Input::input_offsets) + .def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens) .def_readwrite("block_tables", &InferEngine::Input::block_tables) .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping) .def_readwrite("temperature", &InferEngine::Input::temperature) diff --git a/examples/bench.py b/examples/bench.py index 3f9de226..9ac2b11e 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -22,6 +22,8 @@ "float32": 4, } +_PAGED_KV_BLOCK_SIZE = 256 + # BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128] # INPUT_LENS = [32, 256, 1024, 4096] # OUTPUT_LENS = [256, 1024, 4096] @@ -167,6 
+169,11 @@ def get_args(): action="store_true", help="Run alippu test", ) + parser.add_argument( + "--hygon", + action="store_true", + help="Run hygon test", + ) parser.add_argument( "--model", type=str, @@ -229,6 +236,12 @@ def get_args(): action="store_true", help="use paged cache", ) + parser.add_argument( + "--paged_kv_block_size", + type=int, + default=256, + help="num tokens each kv block can hold", + ) parser.add_argument( "--enable-graph", action="store_true", @@ -237,12 +250,20 @@ def get_args(): parser.add_argument( "--warmup", action="store_true", - help="Perform a warmup run before benchmarking/inference." + help="Perform a warmup run before benchmarking/inference.", + ) + parser.add_argument( + "--attn", + type=str, + default="default", + choices=["default", "flash-attn"], + help="attention backend to use: 'default' or 'flash-attn'", ) return parser.parse_args() -prompt = "泰山,又名岱山、岱宗、岱岳、东岳、泰岳,为五岳之一,有“五岳之首”、“五岳独尊”、“天下第一山”、“华夏神山”之称 ,被中外学者称为“中国的奥林匹斯山” 位于山东省中部,隶属于泰安市,绵亘于泰安、济南、淄博三市之间,总面积25000公顷,主峰玉皇顶海拔约1545米。泰山相伴上下五千年的华夏文明传承历史,集国家兴盛、民族存亡的象征于一身,是中华民族的精神家园 [31],东方文化的缩影,“天人合一”思想的寄托之地 [24],承载着丰厚的地理历史文化内涵 [15],被古人视为“直通帝座”的天堂,成为百姓崇拜,帝王告祭的神山,有“泰山安,四海皆安”的说法 [1]。自秦始皇起至清代,先后有13代帝王亲登泰山封禅或祭祀,另有24代帝王遣官祭祀72次。山体上既有寺庙、宫、观等古建筑群29处,古遗址128处,有大小碑碣、摩崖石刻2000余处 [15]。其景巍峨雄奇、幽奥俊秀,有石坞松涛、云海玉盘等美丽壮阔的自然景观。其历史文化、自然风光、地质奇观和谐融为一体,具有特殊的历史、文化、美学和科学价值。 [19]1982年,泰山被列入第一批国家级风景名胜区。1987年,泰山被联合国教科文组织批准列为全球首例世界文化与自然双重遗产 [14] [41-42]。2002年,泰山被评为“中华十大文化名山”之首 [15]。2005年,泰山成为国家地质公园。2006年,泰山因其独特的地质价值成为世界地质公园 [14]。2007年3月,泰山被评为国家AAAAA级旅游景区;12月,泰山被命名为中国首座“中国书法名山”。2025年3月20日,泰山迎来2025年第100万名游客。" +with open("examples/bench_prompt.md", "r") as f: + prompt = f.read() def repeat_prompt(input_ids: list[int], target_length: int): @@ -264,6 +285,7 @@ def __init__( skip_load=False, cache_config=None, enable_graph=False, + attn_backend="default", ) -> None: model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -275,6 +297,7 @@ def __init__( 
distributed_config=DistConfig(tp), cache_config=cache_config, enable_graph_compiling=enable_graph, + attention_backend=attn_backend, ) # ---------------------------------------------------------------------------- # @@ -287,13 +310,13 @@ def __init__( # 创建 tokenizer # ---------------------------------------------------------------------------- # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - + if tokenizer.pad_token is None: if tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id else: - tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # ---------------------------------------------------------------------------- # # token编码 @@ -312,9 +335,8 @@ def __init__( input_content, padding=True, truncation=True, - max_length=2048, - return_tensors="pt" - ) + max_length=8192, + ) input_ids_list = encoding["input_ids"] @@ -349,6 +371,7 @@ def run( top_k=top_k, top_p=top_p, temperature=temperature, + stop_on_eos=False, ), _measure_and_log_time=True, ) @@ -386,11 +409,14 @@ def run( device_str = "mlu" elif args.ali: device_str = "cuda" + elif args.hygon: + device_str = "cuda" else: print( "python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50" ) sys.exit(1) + _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size # -------------------------------------------------------- # # 解析参数 # -------------------------------------------------------- # @@ -422,10 +448,14 @@ def run( # 测试 # -------------------------------------------------------- # if enable_paged_attn: - paged_kv_block_size = 16 + paged_kv_block_size = _PAGED_KV_BLOCK_SIZE max_num_blocks = max( [ - ((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"] + ( + (c_["input_len"] + c_["output_len"] + (paged_kv_block_size - 1)) + // paged_kv_block_size + ) + * c_["batch_size"] for _, c_ in 
cases_dict.items() ] ) @@ -440,6 +470,7 @@ def run( skip_load=skip_load, cache_config=cache_config, enable_graph=enable_graph, + attn_backend=args.attn, ) # ---------------------------------------------------------------------------- # @@ -459,10 +490,7 @@ def run( ) ) - avg_prompt_len = min( - 64, - max(len(ids) for ids in test.input_ids_list) - ) + avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list)) warmup_ids = [ ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids @@ -477,10 +505,11 @@ def run( _ = test.model.generate( input_ids_infini, GenerationConfig( - max_new_tokens=5, # decode kernel warmup + max_new_tokens=5, # decode kernel warmup temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + stop_on_eos=False, ), _measure_and_log_time=False, ) @@ -495,7 +524,6 @@ def run( # Warmup done # ---------------------------------------------------------------------------- # - for idx, case in tqdm(cases_dict.items(), desc="Processing cases"): tqdm.write(f"\033[92mProcessing : {case}\033[0m") diff --git a/examples/bench_prompt.md b/examples/bench_prompt.md new file mode 100644 index 00000000..9b71850c --- /dev/null +++ b/examples/bench_prompt.md @@ -0,0 +1,19 @@ +泰山,又名岱山、岱宗、岱岳、东岳、泰岳,为五岳之一,有“五岳之首”、“五岳独尊”、“天下第一山”、“华夏神山”之称,被中外学者称为“中国的奥林匹斯山”。它位于山东省中部,隶属于泰安市,绵亘于泰安、济南、淄博三市之间,总面积25000公顷,主峰玉皇顶海拔约1545米。这座雄伟的山脉,以其磅礴的气势和深厚的内涵,横亘在齐鲁大地上,成为中华民族精神家园中无可替代的坐标。泰山相伴上下五千年的华夏文明传承历史,集国家兴盛、民族存亡的象征于一身,是中华民族的精神家园,东方文化的缩影,“天人合一”思想的寄托之地,承载着丰厚的地理历史文化内涵,被古人视为“直通帝座”的天堂,成为百姓崇拜,帝王告祭的神山,有“泰山安,四海皆安”的说法。自秦始皇起至清代,先后有13代帝王亲登泰山封禅或祭祀,另有24代帝王遣官祭祀72次。山体上既有寺庙、宫、观等古建筑群29处,古遗址128处,有大小碑碣、摩崖石刻2000余处。其景巍峨雄奇、幽奥俊秀,有石坞松涛、云海玉盘等美丽壮阔的自然景观。其历史文化、自然风光、地质奇观和谐融为一体,具有特殊的历史、文化、美学和科学价值。1982年,泰山被列入第一批国家级风景名胜区。1987年,泰山被联合国教科文组织批准列为全球首例世界文化与自然双重遗产。2002年,泰山被评为“中华十大文化名山”之首。2005年,泰山成为国家地质公园。2006年,泰山因其独特的地质价值成为世界地质公园。2007年3月,泰山被评为国家AAAAA级旅游景区;12月,泰山被命名为中国首座“中国书法名山”。2025年3月20日,泰山迎来2025年第100万名游客。 
+要真正理解泰山,必须首先从它的地理形胜与地质传奇开始。泰山崛起于华北平原之东,巍然矗立于齐鲁大地之上,其形成的历史可追溯至二十八亿年前的太古代时期。那时,这里还是一片浩瀚的海洋,地壳运动如同巨人的手掌,将海底的沉积岩层层挤压、褶皱、抬升,历经沧海桑田的巨变,终于在大约三千万年前的喜马拉雅造山运动中,奠定了今日泰山的雏形。泰山地区的岩石,以片麻岩、花岗岩和闪长岩为主,这些古老的岩石记录了地球童年时代的记忆,它们是整个华北地区最古老的地层之一,被地质学家称为“泰山杂岩”。登临泰山,抚摸那布满斑驳纹路的石壁,仿佛能够触碰到时间的骨骼,感受到地球脉动的余温。 +泰山的地貌格局,呈现出一种阶梯状上升的特征。从山脚的泰安城,海拔仅有二十余米,到中天门海拔八百余米,再到南天门的一千四百余米,最后抵达玉皇顶的一千五百四十五米,每一级台阶都是地质力量的杰作。泰山的山势陡峭而雄伟,但并不显得孤傲与冷漠。它像一位端坐的巨人,既有威严刚毅的面容,也有宽厚仁慈的胸怀。泰山的主峰玉皇顶,因建有玉皇庙而得名,立于其上,极目远眺,只见群山拱卫,众水朝宗,天高地迥,宇宙无穷。孔子“登泰山而小天下”的慨叹,杜甫“会当凌绝顶,一览众山小”的诗句,皆由此生发。这种视觉上的震撼与心理上的升华,正是泰山地理形胜给予世人最直接的馈赠。 +泰山的自然景观,堪称一部恢弘壮丽的交响乐章。四季更替为泰山披上不同的盛装,春日山花烂漫,夏日绿荫如盖,秋日红叶满山,冬日银装素裹。而最令人叹为观止的,当属泰山的四大奇观:泰山日出、云海玉盘、晚霞夕照、黄河金带。黎明时分,站在日观峰上,东方地平线渐渐泛起鱼肚白,须臾间,一道红光喷薄而出,继而半圆形、扇形的金色光芒四射,最后一轮红日跃出云海,天地为之色变。这壮丽的日出景象,自古以来就是帝王封禅告祭的神圣时刻,也是无数游人心驰神往的精神洗礼。若逢雨过天晴,云海便会悄然铺陈,茫茫云涛如素绢白纱,群峰若隐若现,宛如海上仙山。而当夕阳西下,晚霞映照山峦,整座泰山便笼罩在一片金红交织的光晕之中,瑰丽不可方物。更为罕见的是黄河金带,在天气极其晴朗的日子,站在泰山之巅向西眺望,但见黄河如一条金色的丝带,蜿蜒于天际之间,这是大自然赐予泰山独有的殊荣。 +泰山的植被资源极为丰富,森林覆盖率高达百分之八十以上,拥有各类植物一千余种。从山麓的侧柏林,到山腰的油松林,再到山顶的灌丛草甸,垂直分布的植被带谱清晰分明。其中,最负盛名的当属泰山“三大奇观”之一的“石坞松涛”。在后石坞,古松参天而立,枝干虬曲盘错,风过处,松涛阵阵,如万马奔腾,如海潮涌动,这苍茫的松声与古老的岩石相互唱和,奏响了泰山自然乐章中最雄浑的低音。还有那千年汉柏,相传为汉武帝封禅泰山时所植,至今依然枝繁叶茂,气宇轩昂,见证了无数王朝的更迭与岁月的流转。这些古树名木,是泰山活着的文物,也是自然与历史交融的绿色丰碑。 +如果说自然地理赋予了泰山的骨骼与容颜,那么历史文化则注入了泰山的魂魄与神韵。泰山的历史,是一部浓缩的中华文明史。早在远古时期,泰山就被先民视为通天的神山。大汶口文化和龙山文化的考古发现表明,泰山周围是新石器时代中华先民重要的活动区域,他们在此繁衍生息,观察天象,祭祀山川,开启了泰山崇拜的源头。进入文明时代后,泰山更是被赋予了无与伦比的政治与文化内涵。历代帝王将泰山视为江山社稷的象征,天下太平的符瑞,于是便有了绵延数千年的封禅大典。 +封禅,是泰山独有的、最高规格的国家祭祀大典。所谓“封”,是指在泰山之巅筑土为坛,祭天以报天之功;所谓“禅”,是指在泰山脚下的小山辟基扫地为场,祭地以报地之功。这一仪式并非寻常帝王可为,必须是改朝换代、功成治定、天下太平之后,才有资格告成于天。公元前二一九年,秦始皇统一六国后,率领文武大臣,跋涉千里,登上泰山,行封禅礼,刻石颂德,成为历史上有明确记载的第一位封禅泰山的帝王。此后,汉武帝八次东巡,六次封禅,将泰山祭祀推向鼎盛。汉武帝在泰山立碑无字,任后人评说功过,其胸襟气度,至今令人感慨。汉光武帝刘秀、唐高宗李治与武则天、唐玄宗李隆基、宋真宗赵恒等帝王,都曾在泰山举行过规模宏大的封禅仪式。那些封禅台遗址、御帐坪遗址、古登封台遗址,至今依然静卧在山林之间,虽已荒草丛生,却依然能够唤起人们对那个恢弘时代的追忆与想象。 
+除了帝王封禅,历代的文人墨客、名士高僧也在泰山留下了深深的足迹。孔子登泰山而感叹“苛政猛于虎”,这里是他“小天下”的思想高地;司马迁在《史记》中浓墨重彩地记载封禅大典,泰山成为他史笔之下的神圣坐标;曹植、陆机、李白、杜甫、苏轼、元好问,历代诗坛巨擘无不为泰山挥毫泼墨,留下传诵千古的华章。李白“天门一长啸,万里清风来”的豪放,杜甫“造化钟神秀,阴阳割昏晓”的凝练,早已镌刻在每一个中国人的文化记忆之中。明代大旅行家徐霞客两登泰山,其游记成为后世探寻泰山的重要文献。清代康熙、乾隆二帝,更是对泰山情有独钟,康熙帝撰写《泰山龙脉论》,从风水学的角度阐释泰山与国家兴衰的关系;乾隆帝一生十一次登临泰山,留下了大量的诗碑御笔,至今仍能在山间寻见其墨宝。 +泰山更是一座无与伦比的书法艺术宝库。从山麓至山巅,古道两侧,崖壁之上,碑碣林立,刻石遍布,两千余处摩崖石刻如同一部镌刻在山石之上的中国书法史。经石峪的《金刚经》摩崖石刻,是北齐僧人的杰作,每字径尺,隶楷参半,浑厚苍古,被尊为“大字鼻祖”、“榜书之宗”。唐玄宗的《纪泰山铭》,洋洋千言,镌于大观峰崖壁之上,隶书端庄雄浑,盛唐气象扑面而来。还有秦代李斯小篆残碑,虽仅存数字,却是“书同文”历史变革的实物见证。宋代米芾“第一山”的洒脱,明代张钦“观海”的飘逸,清代刘墉、铁保、阮元等名家的墨迹,以及近现代孙中山、毛泽东、郭沫若等伟人的题词,构成了跨越两千余年的书法艺术长廊。登泰山,不仅是身体的跋涉,更是一场与历代先贤跨越时空的对话,每一步都踏在历史的回响之上。 +泰山的人文景观建筑,同样体现了中国古代建筑艺术的卓越成就。以岱庙为核心的泰山古建筑群,是中国现存规模最大、保存最完整的古代山岳祭祀建筑群之一。位于泰山南麓的岱庙,旧称东岳庙,是历代帝王祭祀泰山神的地方,其建筑规制完全仿照帝王宫殿的格局。天贶殿是岱庙的主体建筑,与北京故宫太和殿、曲阜孔庙大成殿并称中国古代三大宫殿式建筑。殿内供奉东岳大帝神像,四壁绘有著名的《泰山神启跸回銮图》,画中人物六百有余,山川林木、宫殿车马,气势恢宏,是宋代壁画的瑰宝。从岱庙向北,沿着绵延九公里的登山御道,分布着红门宫、万仙楼、斗母宫、壶天阁、中天门、五松亭、对松亭、南天门等数十处古建筑。这些建筑依山就势,巧妙利用地形,与自然山水浑然一体,将漫长的登山路线串联成一条神圣的天梯。尤其是十八盘,一千六百余级石阶如同天梯倒挂,两侧悬崖壁立,登临其上,真有“天门云梯”之感。而当穿过南天门,踏上“天街”,仿佛真的走进了天庭仙界,那种精神上的升华与震撼,是任何语言都难以尽述的。 +泰山民俗文化同样源远流长,丰富多彩。泰山石敢当的信仰习俗,遍及海内外华人社区。那一方方镌刻着“石敢当”或“泰山石敢当”字样的石碑,立于村口巷尾、桥头路冲,用以镇宅辟邪、保境安民,至今已有一千余年的传承历史。泰山老奶奶碧霞元君的信仰,更是深入人心。碧霞元君是泰山女神,宋真宗封禅泰山时发现玉女石像,敕建昭真祠供奉,后世累加封号,至明代已成为华北地区影响最大的女神信仰。每年农历四月的东岳庙会,香客云集,商贾辐辏,戏曲杂耍,百戏纷呈,是泰山民俗文化的集中展示。至今,依然有无数的善男信女,不远千里,徒步登山,朝拜碧霞元君,延续着这份源自古老时代的虔诚与敬畏。 +泰山还承载着中华民族融合团结的历史记忆。汉武帝泰山封禅时,南越、东越、西域诸国使节皆随行同祭,显示了大一统王朝的包容与气度。唐代文成公主入藏和亲,行前专程遣使祭祀泰山,祈求国家安宁、民族和睦。金元时期,少数民族统治者入主中原后,同样延续了泰山祭祀的传统,金世宗、元世祖都曾遣官致祭,泰山成为民族融合的文化象征。明永乐年间,女真首领阿哈出率部族朝贡明成祖于泰山脚下,成为东北边疆归附中央的历史见证。泰山见证了中华民族多元一体格局的形成与发展,是各民族文化认同的共同精神家园。 
+泰山的文化内涵,还体现在它作为“国山”地位的最终确立。从先秦时期“泰山岩岩,鲁邦所瞻”的诗篇,到汉代“泰山不让土壤,故能成其大”的哲理,再到宋代范仲淹“先天下之忧而忧,后天下之乐而乐”的名句与泰山精神的契合,泰山已不仅仅是一座地理意义上的山,而是国家社稷、民族精神的象征。当国家危难之际,仁人志士常以泰山自励,文天祥“人生自古谁无死,留取丹心照汗青”的铮铮铁骨,于谦“粉骨碎身浑不怕,要留清白在人间”的凛然正气,无不折射出泰山所象征的坚贞不屈。抗日战争时期,泰山成为鲁中抗日根据地的坚强屏障,八路军将士在泰山周边与日寇浴血奋战,书写了可歌可泣的英雄篇章。冯玉祥将军隐居泰山期间,积极宣传抗日救亡,其墓至今安卧在泰山之麓,与山河同在,与日月同辉。1949年中华人民共和国成立后,泰山作为中华民族精神的象征,受到了前所未有的保护与尊崇。泰山被列为国家重点风景名胜区,岱庙、泰山古建筑群被列入全国重点文物保护单位。泰山所承载的自强不息、厚德载物、国泰民安的文化精神,成为中华民族伟大复兴的精神动力。 +泰山的文化与自然价值获得了国际社会的高度认可。1987年,联合国教科文组织世界遗产委员会将泰山列入世界文化与自然双重遗产名录,这是全球首例获此殊荣的项目。世界遗产委员会的评语写道:“泰山在近两千年的历史中,一直是中国艺术家和学者的精神源泉,是古代中国文明与信仰的象征。”这一评价精准地概括了泰山在世界文明格局中的独特地位。2006年,泰山因其独特的地质构造、典型的地质遗迹和重要的地质历史记录,被批准为世界地质公园。至此,泰山成为中国唯一拥有世界文化、自然、地质三重国际桂冠的名山。这是泰山的荣耀,也是中华文明对人类文明宝库作出的杰出贡献。 +进入新时代,泰山的保护与传承工作进入了新的历史阶段。面对每年数百万游客的登山朝圣,如何平衡遗产保护与旅游开发、文化传承与经济发展,成为摆在管理者面前的重要课题。泰山管理部门坚持“保护为主、抢救第一、合理利用、加强管理”的方针,实施了一系列卓有成效的保护工程。岱庙天贶殿宋代壁画数字化保护项目,利用先进技术为千年国宝建立了完整的数字档案;泰山古建筑群修缮工程,严格遵循不改变文物原状的原则,恢复了古建筑的历史风貌;泰山封禅祭祀文化、泰山石敢当习俗、泰山庙会等非物质文化遗产项目,分别被列入国家级非物质文化遗产名录,得到了系统的保护与活态传承。同时,泰山景区积极推进智慧景区建设,通过预约限流、数字化导览、环境监测等手段,实现了遗产保护与游客满意度的双赢。 +2025年3月20日,春分时节,泰山迎来了今年的第100万名游客。这是一个具有象征意义的数据,它表明泰山作为“天下第一山”的永恒魅力。来自五湖四海的游客,或虔诚朝拜,或慕名探访,或挑战自我,每个人都以自己的方式与这座神山对话。清晨五点,日观峰上已站满了等待日出的年轻人;夕阳西下,天街上的游客依然流连忘返。来自广东的大学生小林,背着帐篷徒步登上泰山,他说:“登泰山就像阅读一部浓缩的中国史,每一步都在与古人对话。”年过七旬的退休教师王先生,已第十次登上泰山,他在大观峰前久久伫立:“每次来泰山,都有新的感悟。泰山不老,它永远年轻。”这些普通人的话语,道出了泰山跨越千年依然鲜活的生命力。 +站在玉皇顶,极目四望,云海苍茫,群山俯首。这一刻,你能够真切地理解,为什么泰山能够成为五岳独尊,为什么它被称为中华民族的精神家园。泰山的伟大,不在于它的绝对高度,而在于它与中华文明相伴相生的五千年历程,在于它所承载的历史记忆与民族情感,在于它所昭示的自强不息、厚德载物的文化精神。泰山是一座山,但泰山又不仅仅是一座山。它是大地隆起的史册,是文明镌刻的丰碑,是每一个中国人心中永恒的精神坐标。 +从远古洪荒到人工智能时代,从秦皇汉武到新时代青年,泰山始终以沉静的目光注视着华夏大地的沧桑巨变。它的岩石记录着地球童年的记忆,它的石刻承载着先贤哲思的光芒,它的松涛传唱着千古不朽的诗篇。泰山安,四海皆安。这朴素的话语中,蕴含着中国人对天下太平、国泰民安最深切的期盼与最坚定的信念。当我们登上泰山,实际上是在攀登一座精神的阶梯,是在进行一场穿越时空的文化寻根。泰山就在那里,沉稳、厚重、亘古不变,一如中华民族生生不息的文化血脉,一如东方文明历久弥新的智慧光芒。 +泰山如坐,静观天下风云;泰山如鼎,承载民族精神。这便是岱宗,这便是泰山。 diff --git a/examples/jiuge.py b/examples/jiuge.py index 2acc7fb3..1196fd78 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -15,6 +15,8 @@ sys.path.insert(0, 
os.path.join(os.path.dirname(__file__), "../python")) +_PAGED_KV_BLOCK_SIZE = 256 + def get_args(): parser = argparse.ArgumentParser(description="run Llama args") @@ -105,6 +107,14 @@ def get_args(): action="store_true", help="use paged cache", ) + + parser.add_argument( + "--paged_kv_block_size", + type=int, + default=256, + help="num tokens each kv block can hold", + ) + parser.add_argument( "--enable-graph", action="store_true", @@ -131,6 +141,19 @@ def get_args(): default=1.0, help="sampling temperature", ) + parser.add_argument( + "--warmup", + action="store_true", + help="Perform a warmup run before benchmarking/inference." + ) + + parser.add_argument( + "--attn", + type=str, + default="default", + choices=["default", "flash-attn"], + help="attention backend to use: 'default' or 'flash-attn'", + ) return parser.parse_args() @@ -146,6 +169,7 @@ def test( top_k=1, top_p=1.0, temperature=1.0, + attn_backend="default", ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -156,6 +180,7 @@ def test( device=infini_device, distributed_config=DistConfig(tp), enable_graph_compiling=enable_graph, + attention_backend=attn_backend, ) # ---------------------------------------------------------------------------- # # Load Weights @@ -225,7 +250,11 @@ def test( batch_size = 1 if prompts is str else len(prompts) max_total_tokens = max_new_tokens + len(input_ids_list[0]) cache_config = PagedKVCacheConfig( - num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16 + num_blocks=( + (max_total_tokens + (_PAGED_KV_BLOCK_SIZE - 1)) // _PAGED_KV_BLOCK_SIZE + ) + * batch_size, + block_size=_PAGED_KV_BLOCK_SIZE, ) else: batch_size = 1 if prompts is str else len(prompts) @@ -236,6 +265,44 @@ def test( model.reset_cache(cache_config) + # ---------------------------------------------------------------------------- # + # Warmup + # 
---------------------------------------------------------------------------- # + if args.warmup: + warmup_steps = 1 + + # Choose a length that approximates the real workload. + # It should be long enough to trigger the correct kernel paths, + # but not so long that warmup becomes unnecessarily expensive. + avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list)) + + # Use truncated versions of real prompts for warmup + warmup_ids = [ + ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids + for ids in input_ids_list + ] + + input_ids_infini = infinicore.from_list(warmup_ids) + + print("=================== warmup start ===================") + + for _ in range(warmup_steps): + _ = model.generate( + input_ids_infini, + GenerationConfig( + max_new_tokens=2, # warmup decode kernel + temperature=temperature, + top_k=top_k, + top_p=top_p, + ), + _measure_and_log_time=False, + ) + + print("=================== warmup done ====================") + + # Reset KV cache + model.reset_cache(cache_config) + # ---------------------------------------------------------------------------- # # Generate # ---------------------------------------------------------------------------- # @@ -295,6 +362,7 @@ def test( ) sys.exit(1) prompts = [args.prompt for _ in range(args.batch_size)] + _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size model_path = args.model_path max_new_tokens = args.max_new_tokens @@ -318,4 +386,5 @@ def test( top_k=args.top_k, top_p=args.top_p, temperature=args.temperature, + attn_backend=args.attn, ) diff --git a/include/infinicore_infer/cache.h b/include/infinicore_infer/cache.h index c6693914..522f2235 100644 --- a/include/infinicore_infer/cache.h +++ b/include/infinicore_infer/cache.h @@ -3,7 +3,7 @@ #include -__C __export struct KVCache *createKVCache( +__INFINI_C __export struct KVCache *createKVCache( size_t nlayers, size_t max_len, size_t nkvh_, @@ -14,8 +14,8 @@ __C __export struct KVCache *createKVCache( int *dev_ids, size_t ndev); -__C __export 
struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len); +__INFINI_C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len); -__C __export void dropKVCache(KVCache *kv_cache); +__INFINI_C __export void dropKVCache(KVCache *kv_cache); #endif /* CACHE_H */ diff --git a/include/infinicore_infer/models/deepseek.h b/include/infinicore_infer/models/deepseek.h index 3924c5fe..d7d2e686 100644 --- a/include/infinicore_infer/models/deepseek.h +++ b/include/infinicore_infer/models/deepseek.h @@ -103,26 +103,26 @@ typedef struct { /// @param device 协处理器种类 /// @param ndev 协处理器数量 /// @param dev_ids 协处理器编号,长度为 ndev -__C __export struct DeepSeekV3Model * +__INFINI_C __export struct DeepSeekV3Model * createDeepSeekV3Model(const DeepSeekV3Meta *, const DeepSeekV3Weights *); -__C DeepSeekV3Weights * +__INFINI_C DeepSeekV3Weights * createDeepSeekV3Weights(const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, const int *dev_ids); -__C __export DeepSeekV3WeightLoader * +__INFINI_C __export DeepSeekV3WeightLoader * createDeepSeekV3WeightLoader(); /// @brief 销毁模型 -__C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *); +__INFINI_C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *); -__C __export struct DeepSeekV3Cache * +__INFINI_C __export struct DeepSeekV3Cache * createDeepSeekV3Cache(const struct DeepSeekV3Model *); -__C __export void +__INFINI_C __export void dropDeepSeekV3Cache(const struct DeepSeekV3Model *, struct DeepSeekV3Cache *); @@ -137,7 +137,7 @@ dropDeepSeekV3Cache(const struct DeepSeekV3Model *, /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void inferBatchDeepSeekV3(struct DeepSeekV3Model *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -153,7 +153,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *, /// @param req_pos 每个请求的起始位置 /// 
@param kv_caches 每个请求的 KV Cache /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void forwardBatchDeepSeekV3(struct DeepSeekV3Model *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index ee0a78c0..824ea8a8 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -54,7 +54,7 @@ typedef struct /// @param device 协处理器种类 /// @param ndev 协处理器数量 /// @param dev_ids 协处理器编号,长度为 ndev -__C __export struct JiugeModel * +__INFINI_C __export struct JiugeModel * createJiugeModel(const JiugeMeta *, const JiugeWeights *, infiniDevice_t device, @@ -62,7 +62,7 @@ createJiugeModel(const JiugeMeta *, const int *dev_ids); /// @brief 销毁模型 -__C __export void +__INFINI_C __export void destroyJiugeModel(struct JiugeModel *); /// @brief 批次推理一轮,并采样出新的 token @@ -76,7 +76,7 @@ destroyJiugeModel(struct JiugeModel *); /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void inferBatchJiuge(struct JiugeModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -92,7 +92,7 @@ inferBatchJiuge(struct JiugeModel *, /// @param req_pos 每个请求的起始位置 /// @param kv_caches 每个请求的 KV Cache /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void forwardBatchJiuge(struct JiugeModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, diff --git a/include/infinicore_infer/models/jiuge_awq.h b/include/infinicore_infer/models/jiuge_awq.h index de4f94f8..631eb3ce 100644 --- a/include/infinicore_infer/models/jiuge_awq.h +++ b/include/infinicore_infer/models/jiuge_awq.h @@ -25,7 +25,7 @@ typedef struct } JiugeAWQMeta; //////////////////// APIs 
/////////////////////// -__C __export struct ModelWeights * +__INFINI_C __export struct ModelWeights * createJiugeAWQWeights(const JiugeAWQMeta *, infiniDevice_t device, int ndev, @@ -34,12 +34,12 @@ createJiugeAWQWeights(const JiugeAWQMeta *, /// @param device 协处理器种类 /// @param ndev 协处理器数量 /// @param dev_ids 协处理器编号,长度为 ndev -__C __export struct JiugeAWQModel * +__INFINI_C __export struct JiugeAWQModel * createJiugeAWQModel(const JiugeAWQMeta *, const ModelWeights *); /// @brief 销毁模型 -__C __export void +__INFINI_C __export void destroyJiugeAWQModel(struct JiugeAWQModel *); /// @brief 批次推理一轮,并采样出新的 token @@ -53,7 +53,7 @@ destroyJiugeAWQModel(struct JiugeAWQModel *); /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void inferBatchJiugeAWQ(struct JiugeAWQModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -69,7 +69,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *, /// @param req_pos 每个请求的起始位置 /// @param kv_caches 每个请求的 KV Cache /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq -__C __export void +__INFINI_C __export void forwardBatchJiugeAWQ(struct JiugeAWQModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, diff --git a/include/infinicore_infer/weights_loader.h b/include/infinicore_infer/weights_loader.h index 90737157..82eafe59 100644 --- a/include/infinicore_infer/weights_loader.h +++ b/include/infinicore_infer/weights_loader.h @@ -5,10 +5,10 @@ struct ModelWeights; -__C __export void +__INFINI_C __export void loadModelWeight(struct ModelWeights *weights, const char *name, void *data); -__C __export void +__INFINI_C __export void loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank); #endif // WEIGHTS_LOADER_H diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index 
e0fe8168..354309e1 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -17,7 +17,7 @@ class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( self, num_blocks: int, - block_size: int = 16, + block_size: int = 256, ): _infinilm.PagedKVCacheConfig.__init__( self, diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 6dfcbbcd..8b614980 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -29,26 +29,20 @@ def __init__( distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, + attention_backend="default", ): self.config = AutoConfig.from_pretrained(model_path) if device is None: device = infinicore.device() - - # super().__init__( - # self.config, - # distributed_config._underlying, - # device._underlying.type, - # cache_config, - # enable_graph_compiling, - # ) - + super().__init__( model_path, distributed_config._underlying, device._underlying.type, cache_config, enable_graph_compiling, + attention_backend, ) self.use_cache = False @@ -65,6 +59,7 @@ def forward( past_kv_lengths=None, total_kv_lengths=None, input_offsets=None, + cu_seqlens=None, block_tables=None, slot_mapping=None, temperature=None, @@ -82,6 +77,7 @@ def forward( ) input_offsets = input_offsets._underlying if input_offsets is not None else None block_tables = block_tables._underlying if block_tables is not None else None + cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None return infinicore.Tensor( @@ -93,6 +89,7 @@ def forward( past_sequence_lengths=past_kv_lengths, total_sequence_lengths=total_kv_lengths, input_offsets=input_offsets, + cu_seqlens=cu_seqlens, block_tables=block_tables, slot_mapping=slot_mapping, temperature=temperature, @@ -109,7 +106,6 @@ def generate( generation_config, *, _measure_and_log_time=False, - paged_block_size=16, ): if 
generation_config.eos_token_id is None: eos_token_id = self.config.eos_token_id @@ -133,6 +129,7 @@ def generate( block_tables = None max_blocks_per_batch = 0 if self.enable_paged_attn: + paged_block_size = self.get_cache_config().block_size() max_blocks_per_batch = ( initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1 ) // paged_block_size @@ -143,7 +140,7 @@ def generate( ] block_tables = infinicore.from_list( block_tables_list, - dtype=infinicore.int64, + dtype=infinicore.int32, ) for iter in range(0, generation_config.max_new_tokens): @@ -196,14 +193,17 @@ def generate( slot_mapping = None past_kv_lengths = infinicore.from_list( - [past_seq_len] * batch_size, dtype=infinicore.int64 + [past_seq_len] * batch_size, dtype=infinicore.int32 ) total_kv_lengths = infinicore.from_list( - [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64 + [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32 + ) + cu_seqlens = infinicore.from_list( + [(past_seq_len + seq_len) * i for i in range(batch_size + 1)], + dtype=infinicore.int32, ) - input_offsets = infinicore.from_list( - [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64 + [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32 ) output_id = self( @@ -212,6 +212,7 @@ def generate( past_kv_lengths=past_kv_lengths, total_kv_lengths=total_kv_lengths, input_offsets=input_offsets, + cu_seqlens=cu_seqlens, block_tables=block_tables, slot_mapping=slot_mapping, temperature=generation_config.temperature, diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index b82b6f48..7293a2fe 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -13,6 +13,9 @@ from typing import List, Optional, Union, AsyncIterator from dataclasses import dataclass +from transformers import AutoTokenizer +from tokenizers import decoders as _dec + import infinicore from infinilm.llm.request import ( @@ -29,8 +32,6 @@ from infinilm.infer_engine import 
InferEngine from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig from infinilm.modeling_utils import load_model_state_dict_by_file -from transformers import AutoTokenizer -from tokenizers import decoders as _dec logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ class EngineConfig: top_p: Default top-p sampling parameter. top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. + attn_backend: Attention backend to use ('default', 'flash-attn'). """ model_path: str @@ -63,13 +65,14 @@ class EngineConfig: cache_type: str = "paged" # "paged" or "static" max_batch_size: int = 16 max_tokens: int = 4096 - num_blocks: int = 8 * 1024 - block_size: int = 16 + num_blocks: int = 512 + block_size: int = 256 max_cache_len: int = 4096 temperature: float = 1.0 top_p: float = 0.8 top_k: int = 1 enable_graph: bool = False + attn_backend: str = "default" class LLMEngine: @@ -87,6 +90,7 @@ def __init__(self, config: EngineConfig): device=self.device, distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, + attention_backend=config.attn_backend, ) # Load model weights @@ -222,16 +226,16 @@ def _prepare_model_input(self, model_input_dict: dict) -> dict: if value is None: # Skip None values (block_tables/slot_mapping for static cache) model_input[key] = None + elif key in ["input_ids", "position_ids", "slot_mapping"]: + model_input[key] = infinicore.from_list(value, dtype=infinicore.int64) elif key in [ - "input_ids", - "position_ids", "past_kv_lengths", "total_kv_lengths", "input_offsets", - "slot_mapping", + "cu_seqlens", "block_tables", ]: - model_input[key] = infinicore.from_list(value, dtype=infinicore.int64) + model_input[key] = infinicore.from_list(value, dtype=infinicore.int32) else: # temperature, top_k, top_p, etc. 
model_input[key] = value @@ -249,47 +253,36 @@ def _update_requests( self.scheduler.cache_manager.reset_req_blocks() for req, token_id in zip(requests, sampled_tokens): - req.generated_token_ids.append(token_id) + + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client, skipping update" + ) + continue + if req.is_prefill: req.is_prefill = False + + req.generated_token_ids.append(token_id) + decoded_text = self.detokenize(req.generated_token_ids) + req.generated_text = decoded_text + holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith( + "\ufffd" + ) + + is_finished = self._check_request_finished(req, token_id) + # vLLM-style replacement character handling is primarily relevant for streaming. # For offline generation (no output queue), keep the fast incremental path. if req._output_queue is None: - token_text = self.detokenize([token_id]) - req.generated_text += token_text - else: - # Streaming path: compute delta from a full decode so we can hold back - # trailing '\ufffd' (likely an incomplete UTF-8 sequence). 
- decoded_text = self.detokenize(req.generated_token_ids) - - finished_now = False - # Update generated_text to the latest decode (used for stop-string checks and debugging) - req.generated_text = decoded_text - - if self._check_request_finished(req, token_id): + if is_finished: + if holds_back_incomplete_utf8: + req.generated_text = decoded_text[:-1] req.mark_finished(req.finish_reason) - finished_now = True - - # Remove stop string from generated_text if STOP_STRING finish reason - if req.finish_reason == FinishReason.STOP_STRING: - stop_strings = req.sampling_params.stop or [] - for stop_str in stop_strings: - if decoded_text.endswith(stop_str): - # Remove the stop string from the end - decoded_text = decoded_text[: -len(stop_str)] - req.generated_text = decoded_text - break - - holds_back_incomplete_utf8 = bool( - decoded_text - ) and decoded_text.endswith("\ufffd") - - # vLLM-style: hold back only if we are not on the final chunk. - # Suppress output when finish reason is LENGTH or STOP_STRING. - # Root cause fix: When STOP_STRING is detected, we suppress output for the token - # that completes the stop string, preventing additional tokens from being output. - if (holds_back_incomplete_utf8 and not finished_now) or ( - finished_now + + else: + if (holds_back_incomplete_utf8 and not is_finished) or ( + is_finished and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING) ): @@ -300,30 +293,29 @@ def _update_requests( if token_text: req._stream_last_yielded_length = len(decoded_text) - # For non-streaming, finish checks happen here. 
- if req._output_queue is None and self._check_request_finished( - req, token_id - ): - req.mark_finished(req.finish_reason) - # Remove stop string from generated_text if STOP_STRING finish reason - if req.finish_reason == FinishReason.STOP_STRING: - stop_strings = req.sampling_params.stop or [] - for stop_str in stop_strings: - if req.generated_text.endswith(stop_str): - # Remove the stop string from the end - req.generated_text = req.generated_text[: -len(stop_str)] - break - # Put output in queue if it exists (for async streaming) - if req._output_queue is not None: + if is_finished: + req.mark_finished(req.finish_reason) output = TokenOutput( request_id=req.request_id, token_id=token_id, token_text=token_text, - finished=req.is_finished(), - finish_reason=req.finish_reason, + finished=is_finished, + finish_reason=req.finish_reason if is_finished else None, generated_text=req.generated_text, ) - req.output_queue.sync_q.put(output) + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted before putting token" + ) + continue + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put token for {req.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." 
+ ) + continue self.scheduler.complete_requests(requests) @@ -341,9 +333,11 @@ def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool: return True # Check stop strings + # Remove stop string from generated_text if STOP_STRING finish reason stop_strings = req.sampling_params.stop or [] for stop_str in stop_strings: if req.generated_text.endswith(stop_str): + req.generated_text = req.generated_text[: -len(stop_str)] req.finish_reason = FinishReason.STOP_STRING return True @@ -385,13 +379,14 @@ def __init__( cache_type: str = "paged", max_batch_size: int = 16, max_tokens: int = 4096, - num_blocks: int = 8 * 1024, - block_size: int = 16, + num_blocks: int = 512, + block_size: int = 256, max_cache_len: int = 4096, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + attn_backend: str = "default", ): """Initialize LLM. @@ -410,6 +405,7 @@ def __init__( top_p: Default top-p sampling parameter. top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. + attn_backend: Attention backend to use ('default', 'flash-attn'). """ config = EngineConfig( model_path=model_path, @@ -426,6 +422,7 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + attn_backend=attn_backend, ) self.engine = LLMEngine(config) self.config = config @@ -538,13 +535,14 @@ def __init__( cache_type: str = "paged", max_batch_size: int = 16, max_tokens: int = 512, - num_blocks: int = 8 * 1024, - block_size: int = 16, + num_blocks: int = 512, + block_size: int = 256, max_cache_len: int = 4096, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + attn_backend: str = "default", ): """Initialize AsyncLLMEngine. @@ -563,6 +561,7 @@ def __init__( top_p: Default top-p sampling parameter. top_k: Default top-k sampling parameter. enable_graph: Whether to enable graph compiling. + attn_backend: Attention backend to use ('default', 'flash-attn'). 
""" config = EngineConfig( model_path=model_path, @@ -579,6 +578,7 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + attn_backend=attn_backend, ) self.engine = LLMEngine(config) self.config = config @@ -732,10 +732,19 @@ async def stream_request( start = time.time() while True: - if request.is_finished() and request.output_queue.async_q.empty(): - break - try: + if request_timeout and time.time() - start > float(request_timeout): + request.mark_timeout() + yield TokenOutput( + request_id=request.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=FinishReason.TIMEOUT, + generated_text=request.generated_text, + ) + break + token_output = await asyncio.wait_for( request.output_queue.async_q.get(), timeout=timeout ) @@ -747,26 +756,28 @@ async def stream_request( if token_output.finished: break except asyncio.TimeoutError: - # Enforce request-level timeout even if no tokens are produced. - if request_timeout is not None: - now = time.time() - if now - start > float(request_timeout): - request.mark_timeout() - yield TokenOutput( - request_id=request.request_id, - token_id=-1, - token_text="", - finished=True, - finish_reason=FinishReason.TIMEOUT, - generated_text=request.generated_text, - ) - break - if request.is_finished(): + logger.warning( + f"Timeout while waiting for token from request {request.request_id}" + ) + if request.is_aborted(): + while not request.output_queue.async_q.empty(): + try: + token_output = request.output_queue.async_q.get_nowait() + request.output_queue.async_q.task_done() + yield token_output + except asyncio.QueueEmpty: + break + + yield TokenOutput( + request_id=request.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=request.finish_reason, + generated_text=request.generated_text, + ) break continue - except asyncio.CancelledError: - request.mark_canceled() - break except Exception as e: - logger.error(f"Error streaming request {request.request_id}: {e}") - await 
asyncio.sleep(0.01) + logger.error(f"Error while streaming request {request.request_id}: {e}") + break diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 224828d1..59d2ea15 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -7,9 +7,13 @@ from typing import List, Optional, Any import time import janus +import asyncio +import logging from infinilm.llm.sampling_params import SamplingParams +logger = logging.getLogger(__name__) + class RequestStatus(Enum): """Status of an inference request.""" @@ -143,6 +147,7 @@ def __init__( # Output management (for async streaming) self._output_queue: Optional[janus.Queue] = None + self._aborted = False # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer) # Used by the engine to compute "delta" text chunks from a full decode. @@ -185,6 +190,14 @@ def is_finished(self) -> bool: RequestStatus.TIMEOUT, ] + def abort(self): + """Signal that the request has been aborted and should stop generation.""" + self._aborted = True + + def is_aborted(self) -> bool: + """Check if the request has been aborted.""" + return self._aborted + def mark_finished(self, reason: FinishReason): """Mark the request as finished with the given reason.""" self.status = RequestStatus.FINISHED @@ -193,18 +206,21 @@ def mark_finished(self, reason: FinishReason): def mark_failed(self, reason: FinishReason = FinishReason.ERROR): """Mark the request as failed.""" + self.abort() self.status = RequestStatus.FAILED self.finish_reason = reason self.finished_time = time.time() def mark_canceled(self): """Mark the request as canceled.""" + self.abort() self.status = RequestStatus.CANCELED self.finish_reason = FinishReason.CANCELED self.finished_time = time.time() def mark_timeout(self): """Mark the request as timed out.""" + self.abort() self.status = RequestStatus.TIMEOUT self.finish_reason = FinishReason.TIMEOUT self.finished_time = time.time() @@ -212,9 +228,25 @@ def mark_timeout(self): 
async def close(self): """Close the output queue and clean up resources.""" if self._output_queue is not None: - await self._output_queue.async_q.join() + self.abort() + try: + while not self._output_queue.async_q.empty(): + try: + self._output_queue.async_q.get_nowait() + self._output_queue.async_q.task_done() + except asyncio.QueueEmpty: + break + except Exception as e: + logger.error( + f"Error while clearing output queue for request {self.request_id}: {e}" + ) + pass + self._output_queue.close() - await self._output_queue.wait_closed() + try: + await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5) + except asyncio.TimeoutError: + logger.warning("wait_closed timeout, force close") def to_request_output(self) -> RequestOutput: """Convert to RequestOutput for external use.""" diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index 91e9c0a1..33ea67d0 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -101,6 +101,9 @@ def build_model_inputs( max_block_table_len - len(req.block_table) ) block_tables.append(padded_block_table) + cu_seqlens = [0] + for l in seq_lens: + cu_seqlens.append(cu_seqlens[-1] + l) return { "input_ids": [tokens], @@ -108,6 +111,7 @@ def build_model_inputs( "past_kv_lengths": cached_lens, "total_kv_lengths": seq_lens, "input_offsets": seq_offsets, + "cu_seqlens": cu_seqlens, "block_tables": block_tables, "slot_mapping": slot_mapping, "temperature": temperature, @@ -128,8 +132,8 @@ class Scheduler: def __init__( self, max_batch_size: int = 16, - num_blocks: int = 8 * 1024, - block_size: int = 16, + num_blocks: int = 512, + block_size: int = 256, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() diff --git a/python/infinilm/llm/static_scheduler.py b/python/infinilm/llm/static_scheduler.py index 82300c6a..e7336242 100644 --- a/python/infinilm/llm/static_scheduler.py +++ b/python/infinilm/llm/static_scheduler.py @@ -7,7 +7,12 @@ import janus 
from typing import List, Optional -from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason +from infinilm.llm.request import ( + RequestStatus, + InferenceRequest, + FinishReason, + TokenOutput, +) logger = logging.getLogger(__name__) @@ -70,6 +75,7 @@ def build_model_inputs( "past_kv_lengths": [past_kv_len], "total_kv_lengths": [total_kv_len], "input_offsets": input_offsets, + "cu_seqlens": [0, total_kv_len], "block_tables": None, "slot_mapping": None, "temperature": temperature, @@ -115,6 +121,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) self.running_request = None req.mark_failed(FinishReason.LENGTH) + output = TokenOutput( + request_id=req.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=req.finish_reason, + generated_text=req.generated_text, + ) + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put completion token for {req.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." + ) continue return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False) @@ -137,6 +158,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) req.mark_failed(FinishReason.LENGTH) + output = TokenOutput( + request_id=req.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=req.finish_reason, + generated_text=req.generated_text, + ) + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put completion token for {req.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." 
+ ) continue req.status = RequestStatus.RUNNING diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index a6197dfe..88afd388 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -11,6 +11,7 @@ import uvicorn import logging import os +import asyncio from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -98,8 +99,8 @@ def __init__( cache_type: str = "paged", max_tokens: int = 4096, max_batch_size: int = 16, - num_blocks: int = 8 * 1024, - block_size: int = 16, + num_blocks: int = 512, + block_size: int = 256, max_cache_len: int = 4096, temperature: float = 1.0, top_p: float = 0.8, @@ -107,6 +108,7 @@ def __init__( host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, + attn_backend: str = "default", ): """Initialize inference server. @@ -127,6 +129,7 @@ def __init__( host: Server host address. port: Server port number. enable_graph: Whether to enable graph compiling. + attn_backend: Attention backend to use ('default', 'flash-attn'). 
""" self.model_path = model_path # vLLM-like served model id: directory name of model_path @@ -146,6 +149,7 @@ def __init__( self.host = host self.port = port self.enable_graph = enable_graph + self.attn_backend = attn_backend self.engine: AsyncLLMEngine = None @@ -176,6 +180,7 @@ async def lifespan(app: FastAPI): top_p=self.top_p, top_k=self.top_k, enable_graph=self.enable_graph, + attn_backend=self.attn_backend, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") @@ -351,6 +356,12 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) timeout=DEFAULT_STREAM_TIMEOUT, request_timeout=DEFAULT_REQUEST_TIMEOUT, ): + # Check client disconnect + if await http_request.is_disconnected(): + logger.info(f"Client disconnected for request {request_id}") + req.mark_canceled() + break + # If stream_request enforces timeout, we can just surface the state to the client. if token_output.finish_reason == FinishReason.TIMEOUT: logger.warning( @@ -368,12 +379,6 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) yield f"data: {error_chunk}\n\n" break - # Check client disconnect - if await http_request.is_disconnected(): - logger.info(f"Client disconnected for request {request_id}") - req.mark_canceled() - break - # Skip EOS token text for OpenAI API compatibility # Check if this token is an EOS token by comparing token_id with eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids @@ -404,6 +409,12 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) yield f"data: {chunk}\n\n" break + except asyncio.CancelledError: + logger.info(f"Request {request_id} was cancelled") + if req: + req.mark_canceled() + raise + except Exception as e: logger.error(f"Stream error for {request_id}: {e}", exc_info=True) if req: @@ -451,23 +462,23 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): timeout=DEFAULT_STREAM_TIMEOUT, 
request_timeout=DEFAULT_REQUEST_TIMEOUT, ): - # Request-level timeout is handled inside stream_request. - if token_output.finish_reason == FinishReason.TIMEOUT: - logger.warning(f"Request {request_id} timed out") - break - # Check client disconnect if await http_request.is_disconnected(): logger.info(f"Client disconnected for request {request_id}") req.mark_canceled() break + # Request-level timeout is handled inside stream_request. + if token_output.finish_reason == FinishReason.TIMEOUT: + logger.warning(f"Request {request_id} timed out") + break + # Skip EOS token text for OpenAI API compatibility # Check if this token is an EOS token by comparing token_id with eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids - if not is_eos_token: + if not is_eos_token and token_output.token_text: output_text += token_output.token_text if token_output.finished: @@ -488,6 +499,12 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): ) return response + except asyncio.CancelledError: + logger.info(f"Request {request_id} was cancelled") + if req: + req.mark_canceled() + raise + except Exception as e: logger.error(f"Chat error for {request_id}: {e}", exc_info=True) if req: @@ -555,13 +572,13 @@ def parse_args(): parser.add_argument( "--num_blocks", type=int, - default=8 * 1024, + default=512, help="Number of blocks for KV cache (paged cache only)", ) parser.add_argument( "--block_size", type=int, - default=16, + default=256, help="Block size for KV cache (paged cache only)", ) parser.add_argument( @@ -594,11 +611,19 @@ def parse_args(): parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device") parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device") parser.add_argument("--ali", action="store_true", help="Use Ali PPU device") + parser.add_argument("--hygon", action="store_true", help="Use Hygon DCU device") 
parser.add_argument( "--enable-graph", action="store_true", help="Enable graph compiling", ) + parser.add_argument( + "--attn", + type=str, + default="default", + choices=["default", "flash-attn"], + help="Attention backend to use: 'default' or 'flash-attn'", + ) parser.add_argument( "--log_level", type=str, @@ -631,15 +656,17 @@ def main(): device = "mlu" elif args.ali: device = "cuda" + elif args.hygon: + device = "cuda" else: print( - "Usage: python infinilm.server.inference_server [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali] " + "Usage: python infinilm.server.inference_server [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] " "--model_path= --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE" "\n" "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ " "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1" "\n" - "Optional: --enable-paged-attn --enable-graph" + "Optional: --enable-paged-attn --enable-graph --attn=default" ) sys.exit(1) @@ -660,6 +687,7 @@ def main(): host=args.host, port=args.port, enable_graph=args.enable_graph, + attn_backend=args.attn, ) server.start() diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..6a33d8f0 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -4,7 +4,6 @@ import argparse import random - PROMPTS = [ "如果猫能写诗,它们会写些什么?", "描述一个没有重力的世界。", @@ -25,11 +24,11 @@ "如果你可以变成任何一种动物,你会选择什么?", "描述一个由机器人统治的未来世界。", "如果你能与任何虚构角色成为朋友,你会选择谁?", - "想象一下,如果每个人都能读懂他人的思想。" + "想象一下,如果每个人都能读懂他人的思想。", ] -NUM_REQUESTS = 10 -CONCURRENCY = 5 +NUM_REQUESTS = 64 +CONCURRENCY = 20 API_URL = "http://127.0.0.1:8000" MODEL = "FM9G-7B" @@ -43,14 +42,14 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): break question = random.choice(PROMPTS) - try: + try: print(f"🚀 User#{user_id} Sending request #{task_id}") start_time = time.time() stream 
= await client.chat.completions.create( model=MODEL, messages=[{"role": "user", "content": question}], - stream=True + stream=True, ) first_token_time = None @@ -71,19 +70,33 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): ttft = first_token_time - start_time if first_token_time else None elapsed_time = end_time - start_time if start_time else None - ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None - tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0 + ms_per_token = ( + (elapsed_time / total_tokens * 1000) + if total_tokens > 0 and elapsed_time + else None + ) + tokens_per_second = ( + total_tokens / elapsed_time if elapsed_time > 0 else 0 + ) answer = "".join(answer_chunks) - results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)) + results.append( + (total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token) + ) if verbose: print(f"\n📝 Request #{task_id} (User #{user_id})") - print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") - print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + if ttft is not None: + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + if elapsed_time is not None: + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + print(f" 🔤 解码 token 总数: {total_tokens}") - print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + if ms_per_token is not None: + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + else: + print(f" 📏 平均 token 解码时间: N/A (no token generated)") print(f" ❓ 提问: {question}") print(f" 💬 回答: {answer}\n") @@ -92,6 +105,8 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): if verbose: print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") print(f" ❌ Error: {e}\n") + queue.task_done() + async def run_benchmark(verbose=False): client = AsyncOpenAI(base_url=API_URL, api_key="default") @@ -104,7 +119,9 @@ async def run_benchmark(verbose=False): await queue.put(None) users = [ - 
asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose)) + asyncio.create_task( + benchmark_user(client, semaphore, queue, results, user_id, verbose) + ) for user_id in range(CONCURRENCY) ] @@ -121,11 +138,19 @@ async def run_benchmark(verbose=False): ms_per_token_list = [r[4] for r in results if r and r[4] is not None] successful_requests = len(results) - requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + requests_per_second = ( + successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + ) avg_latency = sum(latencies) / len(latencies) if latencies else 0 - avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 + avg_tokens_per_second = ( + sum(tokens_per_second_list) / len(tokens_per_second_list) + if tokens_per_second_list + else 0 + ) avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 - avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + avg_ms_per_token = ( + sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + ) width_label = 24 sep = "-" * 60 @@ -142,7 +167,9 @@ async def run_benchmark(verbose=False): print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") - print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s") + print( + f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s" + ) if __name__ == "__main__": @@ -150,6 +177,4 @@ async def run_benchmark(verbose=False): parser.add_argument("--verbose", action="store_true") args = parser.parse_args() - asyncio.run(run_benchmark( - args.verbose - )) + asyncio.run(run_benchmark(args.verbose)) diff --git a/src/cache_manager/kvcache.cpp 
b/src/cache_manager/kvcache.cpp index 1abb5585..99d07dfa 100644 --- a/src/cache_manager/kvcache.cpp +++ b/src/cache_manager/kvcache.cpp @@ -1,6 +1,6 @@ #include "../cache.hpp" -__C struct KVCache *createKVCache( +__INFINI_C struct KVCache *createKVCache( size_t nlayers, size_t max_len, size_t nkvh_, @@ -31,7 +31,7 @@ __C struct KVCache *createKVCache( return cache; } -__C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { +__INFINI_C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { auto ndev = kv_cache->k.size(); auto nlayers = kv_cache->k[0].size(); auto device = kv_cache->k[0][0]->deviceType(); @@ -65,7 +65,7 @@ __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { return new_kv_cache; } -__C void dropKVCache(KVCache *kv_cache) { +__INFINI_C void dropKVCache(KVCache *kv_cache) { auto ndev = kv_cache->k.size(); auto nlayers = kv_cache->k[0].size(); auto device = kv_cache->k[0][0]->deviceType(); diff --git a/src/dataloader/weights_loader.cpp b/src/dataloader/weights_loader.cpp index e5526cb6..71e63486 100644 --- a/src/dataloader/weights_loader.cpp +++ b/src/dataloader/weights_loader.cpp @@ -78,7 +78,7 @@ std::shared_ptr Loader::get(const std::string &name, int rank) { } // namespace infinicore::weights -__C void +__INFINI_C void loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) { std::string name_str(name); auto weights = reinterpret_cast(weights_); diff --git a/src/models/deepseek_v3/deepseek_v3.cpp b/src/models/deepseek_v3/deepseek_v3.cpp index 2c463035..db22d87d 100644 --- a/src/models/deepseek_v3/deepseek_v3.cpp +++ b/src/models/deepseek_v3/deepseek_v3.cpp @@ -431,7 +431,7 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc } } -__C void +__INFINI_C void inferBatchDeepSeekV3(struct DeepSeekV3Model *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -464,7 +464,7 
@@ inferBatchDeepSeekV3(struct DeepSeekV3Model *model, } } -__C void +__INFINI_C void forwardBatchDeepSeekV3(struct DeepSeekV3Model *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -563,14 +563,14 @@ DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3We } } -__C struct DeepSeekV3Model * +__INFINI_C struct DeepSeekV3Model * createDeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3Weights *weights) { DeepSeekV3Model *model = new DeepSeekV3Model(_meta, weights); return model; } -__C void +__INFINI_C void destroyDeepSeekV3Model(struct DeepSeekV3Model *model) { auto ndev = model->dev_resources.size(); diff --git a/src/models/deepseek_v3/deepseek_v3_cache.cpp b/src/models/deepseek_v3/deepseek_v3_cache.cpp index a177fd8c..d6448a61 100644 --- a/src/models/deepseek_v3/deepseek_v3_cache.cpp +++ b/src/models/deepseek_v3/deepseek_v3_cache.cpp @@ -1,6 +1,6 @@ #include "deepseek_v3_impl.hpp" -__C struct DeepSeekV3Cache * +__INFINI_C struct DeepSeekV3Cache * createDeepSeekV3Cache(const struct DeepSeekV3Model *model) { DeepSeekV3Cache *cache = new DeepSeekV3Cache(); auto ndev = model->dev_resources.size(); @@ -25,7 +25,7 @@ createDeepSeekV3Cache(const struct DeepSeekV3Model *model) { return cache; } -__C void +__INFINI_C void dropDeepSeekV3Cache(const struct DeepSeekV3Model *model, struct DeepSeekV3Cache *cache) { auto ndev = model->dev_resources.size(); diff --git a/src/models/deepseek_v3/deepseek_v3_weight.cpp b/src/models/deepseek_v3/deepseek_v3_weight.cpp index 846af633..20a8851d 100644 --- a/src/models/deepseek_v3/deepseek_v3_weight.cpp +++ b/src/models/deepseek_v3/deepseek_v3_weight.cpp @@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = { .load_mlp_experts = load_mlp_experts, }; -__C DeepSeekV3Weights * +__INFINI_C DeepSeekV3Weights * createDeepSeekV3Weights(const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, @@ -445,7 +445,7 @@ 
createDeepSeekV3Weights(const DeepSeekV3Meta *meta, return weights; }; -__C DeepSeekV3WeightLoader * +__INFINI_C DeepSeekV3WeightLoader * createDeepSeekV3WeightLoader() { return &weight_loader; } diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 41f8e5ea..8b65d8f4 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -315,7 +315,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, } } -__C void +__INFINI_C void inferBatchJiuge(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -348,7 +348,7 @@ inferBatchJiuge(struct JiugeModel *model, } } -__C void +__INFINI_C void forwardBatchJiuge(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -444,7 +444,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi } } -__C struct JiugeModel * +__INFINI_C struct JiugeModel * createJiugeModel(const JiugeMeta *meta, const JiugeWeights *weights, infiniDevice_t device, @@ -456,7 +456,7 @@ createJiugeModel(const JiugeMeta *meta, return model; } -__C void destroyJiugeModel(struct JiugeModel *model) { +__INFINI_C void destroyJiugeModel(struct JiugeModel *model) { auto ndev = model->dev_resources.size(); for (size_t idev = 0; idev < ndev; idev++) { diff --git a/src/models/jiuge_awq/jiuge_awq.cpp b/src/models/jiuge_awq/jiuge_awq.cpp index 4452c400..46453de0 100644 --- a/src/models/jiuge_awq/jiuge_awq.cpp +++ b/src/models/jiuge_awq/jiuge_awq.cpp @@ -242,7 +242,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, } } -__C void +__INFINI_C void inferBatchJiugeAWQ(struct JiugeAWQModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -275,7 +275,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *model, } } -__C void +__INFINI_C void 
forwardBatchJiugeAWQ(struct JiugeAWQModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -372,14 +372,14 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh } } -__C struct JiugeAWQModel * +__INFINI_C struct JiugeAWQModel * createJiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weights) { JiugeAWQModel *model = new JiugeAWQModel(meta, weights); return model; } -__C void destroyJiugeAWQModel(struct JiugeAWQModel *model) { +__INFINI_C void destroyJiugeAWQModel(struct JiugeAWQModel *model) { auto ndev = model->dev_resources.size(); for (size_t idev = 0; idev < ndev; idev++) { diff --git a/src/models/jiuge_awq/jiuge_awq_weight.cpp b/src/models/jiuge_awq/jiuge_awq_weight.cpp index b01735d0..9d3dbfde 100644 --- a/src/models/jiuge_awq/jiuge_awq_weight.cpp +++ b/src/models/jiuge_awq/jiuge_awq_weight.cpp @@ -118,7 +118,7 @@ JiugeAWQWeights::JiugeAWQWeights( #undef REGISTER_LAYER_QUANT_WEIGHT } -__C struct ModelWeights * +__INFINI_C struct ModelWeights * createJiugeAWQWeights(const JiugeAWQMeta *meta, infiniDevice_t device, int ndev, diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index 95653366..4b49105a 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -4,11 +4,6 @@ import re import csv import numpy as np -import infinicore -from infinilm.modeling_utils import load_model_state_dict_by_file -from infinilm.distributed import DistConfig -from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig -from infinilm.infer_engine import GenerationConfig, InferEngine from datasets import load_dataset, Dataset from abc import ABC, abstractmethod @@ -57,6 +52,11 @@ def __init__( enable_paged_attn=False, ): import transformers + import infinicore + from infinilm.modeling_utils import load_model_state_dict_by_file + from infinilm.distributed import DistConfig + from infinilm.cache import StaticKVCacheConfig, 
PagedKVCacheConfig + from infinilm.infer_engine import InferEngine self.benchmark = benchmark @@ -73,6 +73,7 @@ def __init__( "iluvatar": "cuda", "kunlun": "cuda", "hygon": "cuda", + "ali": "cuda", } device_name = device_map.get(device_type_str.lower(), "cpu") @@ -102,7 +103,9 @@ def __init__( ) elif model_type in ["qwen2", "qwen3"]: # For qwen2/qwen3 models: no trust_remote_code (matches jiuge line 534-536) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) else: # Default: use trust_remote_code=True for other models self.tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -178,17 +181,26 @@ def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_): which properly handles KV cache through GenerationMixin. """ # Convert tokens to infinicore format + import infinicore + from infinilm.infer_engine import GenerationConfig + input_ids_list = [tokens] input_ids = infinicore.from_list(input_ids_list) start_time = time.perf_counter() # For cpp backend, reset cache before generation if use_cache is enabled - if self.model.use_cache and hasattr(self.model, "_model") and hasattr(self.model._model, "reset_cache"): + if ( + self.model.use_cache + and hasattr(self.model, "_model") + and hasattr(self.model._model, "reset_cache") + ): batch_size = input_ids.shape[0] seq_len = input_ids.shape[1] max_cache_len = max_steps + seq_len - self.model.reset_cache(batch_size=batch_size, initial_capacity=max_cache_len) + self.model.reset_cache( + batch_size=batch_size, initial_capacity=max_cache_len + ) # Use model's built-in generate() method which properly handles KV cache # Pass sampling parameters (temperature, topk, topp) via kwargs @@ -363,6 +375,124 @@ def destroy_model_instance(self): print("Torch model destroyed") +class VLLMBenchmark(BaseBenchmark): + """vLLM backend using vllm.LLM""" + + def __init__( + self, + model_dir_path, 
+ device_type_str="nvidia", + tensor_parallel_size=1, + benchmark="ceval", + ): + import transformers + from vllm import LLM + + if device_type_str == "cpu": + raise ValueError("vLLM backend does not support CPU device type.") + + self.benchmark = benchmark + + # ---- tokenizer ---- + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + import json + + self.config_dict = json.load(f) + + model_type = self.config_dict.get("model_type", "") + if model_type in ["qwen2", "qwen3"]: + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + else: + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + + eos_token_id = self.config_dict.get("eos_token_id") + self.eos_token_id = ( + [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id + ) + + # ---- vLLM engine ---- + print("Loading model with vLLM backend...") + self.llm = LLM( + model=model_dir_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, + ) + print("vLLM model loaded successfully") + + def max_context_len(self): + return self.config_dict.get("max_position_embeddings", 2048) + + def render_input_content(self, *args, **kwargs): + if self.benchmark == "ceval": + return render_ceval(self.tokenizer, *args, **kwargs) + elif self.benchmark == "mmlu": + return render_mmlu(self.tokenizer, *args, **kwargs) + else: + raise ValueError(f"Unknown benchmark: {self.benchmark}") + + def generate(self, *args, max_steps=500, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.render_input_content(*args) + print(input_content, end="", flush=True) + + tokens = self.encode_text(input_content) + return self._generate_step(tokens, max_steps, topp_, topk_, temperature_) + + def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_): + from vllm import SamplingParams + + prompt = self.tokenizer.decode(tokens) + + sampling_params = SamplingParams( + 
max_tokens=max_steps, + temperature=temperature_, + top_p=topp_, + top_k=topk_, + stop_token_ids=self.eos_token_id, + ) + + start_time = time.perf_counter() + + outputs = self.llm.generate( + prompts=[prompt], + sampling_params=sampling_params, + ) + + end_time = time.perf_counter() + + # ---- post process ---- + output_text = outputs[0].outputs[0].text + + # ---- stats ---- + input_tokens = len(tokens) + new_tokens = len(self.encode_text(output_text)) + total_tokens = input_tokens + new_tokens + + total_time = end_time - start_time + throughput = total_tokens / total_time if total_time > 0 else 0.0 + + print(output_text) + print() + print(f"Total time: {total_time * 1000:.2f} ms") + print(f"Input tokens: {input_tokens}") + print(f"New tokens: {new_tokens}") + print(f"Total tokens processed: {total_tokens}") + print(f"Throughput: {throughput:.2f} tok/s") + + global TOTAL_TOKENS, TOTAL_TIME + TOTAL_TOKENS += total_tokens + TOTAL_TIME += total_time + + return output_text + + def destroy_model_instance(self): + del self.llm + print("vLLM model destroyed") + + def render_ceval(_tokenizer, conversation): """Render C-Eval conversation to input content""" return ( @@ -390,13 +520,16 @@ def render_mmlu(_tokenizer, question, choices): if hasattr(_tokenizer, "apply_chat_template"): conversation = [ {"role": "system", "content": instruction}, - {"role": "user", "content": f"{question}\n{choices_text}\nAnswer:"}, + {"role": "user", "content": f"{question}\n{choices_text}\n"}, ] try: - return _tokenizer.apply_chat_template( - conversation=conversation, - add_generation_prompt=True, - tokenize=False, + return ( + _tokenizer.apply_chat_template( + conversation=conversation, + add_generation_prompt=True, + tokenize=False, + ) + + "The answer is: " ) except Exception: return prompt @@ -656,7 +789,7 @@ def test(): # Parse arguments manually to handle device flags properly if len(sys.argv) < 4: print( - "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | 
--metax | --moore | --iluvatar | --kunlun | --hygon] --bench [ceval|mmlu] [--backend cpp|torch] [--ndev N] [--subject SUBJECT] [--split {test|val|all}] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]" + "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] --bench [ceval|mmlu] [--backend cpp|torch|vllm] [--ndev N] [--subject SUBJECT] [--split {test|val|all}] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]" ) sys.exit(1) @@ -739,9 +872,11 @@ def test(): device_type_str = "kunlun" elif device_flag == "--hygon": device_type_str = "hygon" + elif device_flag == "--ali": + device_type_str = "ali" else: print( - "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] --bench [ceval|mmlu] [--backend cpp|torch] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]" + "Usage: python test_benchmark.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] --bench [ceval|mmlu] [--backend cpp|torch|vllm] [--ndev N] [--subject SUBJECT] [--num_samples N] [--max_new_tokens N] [--output_csv PATH] [--cache_dir PATH]" ) sys.exit(1) @@ -764,7 +899,10 @@ def test(): # Create model based on backend (create once, reuse for all subjects) if backend == "torch": + assert ndev == 1, "Torch backend only supports single-device evaluation" model = TorchBenchmark(model_path, device_type_str, benchmark) + elif backend == "vllm": + model = VLLMBenchmark(model_path, device_type_str, ndev, benchmark) else: model = InfiniLMBenchmark( model_path, device_type_str, ndev, backend, benchmark, enable_paged_attn diff --git a/xmake.lua b/xmake.lua index aab1a0c7..c29875aa 100644 --- a/xmake.lua +++ b/xmake.lua @@ -8,6 +8,16 @@ set_toolchains("gcc") 
add_includedirs("third_party/spdlog/include") add_includedirs("third_party/json/single_include/") +option("use-kv-caching") + set_default(false) + set_showmenu(true) + set_description("Whether to compile the path using the kv caching operator") +option_end() + +if has_config("use-kv-caching") then + add_defines("ENABLE_KV_CACHING") +end + target("infinicore_infer") set_kind("shared")