[Bugfix] Add replacement of _compute_slot_mapping_kernel on CPU (#37987)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -126,6 +126,12 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
|
||||
const torch::Tensor& topk_id, const bool skip_weighted,
|
||||
const std::string& act, const std::string& isa);
|
||||
|
||||
void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
|
||||
const torch::Tensor positions,
|
||||
const torch::Tensor block_table,
|
||||
torch::Tensor slot_mapping,
|
||||
const int64_t block_size);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -334,6 +340,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
|
||||
ops.def(
|
||||
"compute_slot_mapping_kernel_impl(Tensor query_start_loc, Tensor "
|
||||
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
|
||||
"block_size) -> ()",
|
||||
&compute_slot_mapping_kernel_impl);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
Reference in New Issue
Block a user