[Bugfix] Add replacement of _compute_slot_mapping_kernel on CPU (#37987)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -3,7 +3,6 @@ depends_on: []
|
||||
steps:
|
||||
- label: CPU-Kernel Tests
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -23,7 +22,6 @@ steps:
|
||||
|
||||
- label: CPU-Compatibility Tests
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -37,7 +35,6 @@ steps:
|
||||
|
||||
- label: CPU-Language Generation and Pooling Model Tests
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -53,7 +50,6 @@ steps:
|
||||
|
||||
- label: CPU-Quantization Model Tests
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -73,7 +69,6 @@ steps:
|
||||
|
||||
- label: CPU-Distributed Tests
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -92,7 +87,6 @@ steps:
|
||||
|
||||
- label: CPU-Multi-Modal Model Tests %N
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
@@ -107,7 +101,6 @@ steps:
|
||||
|
||||
- label: "Arm CPU Test"
|
||||
depends_on: []
|
||||
soft_fail: true
|
||||
device: arm_cpu
|
||||
no_plugin: true
|
||||
commands:
|
||||
|
||||
@@ -126,6 +126,12 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
|
||||
const torch::Tensor& topk_id, const bool skip_weighted,
|
||||
const std::string& act, const std::string& isa);
|
||||
|
||||
void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
|
||||
const torch::Tensor positions,
|
||||
const torch::Tensor block_table,
|
||||
torch::Tensor slot_mapping,
|
||||
const int64_t block_size);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -334,6 +340,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
|
||||
ops.def(
|
||||
"compute_slot_mapping_kernel_impl(Tensor query_start_loc, Tensor "
|
||||
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
|
||||
"block_size) -> ()",
|
||||
&compute_slot_mapping_kernel_impl);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@@ -189,3 +189,38 @@ ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
|
||||
return &manager;
|
||||
}
|
||||
} // namespace cpu_utils
|
||||
|
||||
// Fill `slot_mapping` for every scheduled token.
//
// For each request r, tokens [query_start_loc[r], query_start_loc[r+1]) are
// mapped to physical cache slots: the token's absolute position selects a
// block id from the request's row of `block_table`, and the slot is
// block_id * block_size + (position % block_size).
//
// NOTE(review): assumes query_start_loc is int32 with num_reqs + 1 entries,
// positions/slot_mapping are int64, block_table is int32 — matches the
// data_ptr<> calls below; verify against the Python caller.
void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
                                      const torch::Tensor positions,
                                      const torch::Tensor block_table,
                                      torch::Tensor slot_mapping,
                                      const int64_t block_size) {
  // query_start_loc carries one trailing sentinel entry, hence "- 1".
  const int32_t num_reqs = query_start_loc.size(0) - 1;
  const int64_t table_row_stride = block_table.stride(0);

  const int32_t* __restrict__ start_loc = query_start_loc.data_ptr<int32_t>();
  const int64_t* __restrict__ pos = positions.data_ptr<int64_t>();
  const int32_t* __restrict__ table = block_table.data_ptr<int32_t>();
  int64_t* __restrict__ slots = slot_mapping.data_ptr<int64_t>();

  // Requests write disjoint ranges of slot_mapping, so they can be
  // processed independently.
#pragma omp parallel for
  for (int32_t req = 0; req < num_reqs; ++req) {
    const int32_t first_token = start_loc[req];
    const int32_t last_token = start_loc[req + 1];
    const int32_t* __restrict__ row = table + req * table_row_stride;

    for (int32_t token = first_token; token < last_token; ++token) {
      const int64_t position = pos[token];
      const int64_t block_id = row[position / block_size];
      slots[token] = block_id * block_size + position % block_size;
    }
  }
}
|
||||
|
||||
@@ -161,7 +161,7 @@ RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -e tests/vllm_test_utils
|
||||
uv pip install --no-build-isolation -e tests/vllm_test_utils
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=cache,target=/root/.cache/ccache \
|
||||
|
||||
@@ -309,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
group_size = quant_config.get("group_size")
|
||||
zero_point = quant_config.get("zero_point")
|
||||
|
||||
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
|
||||
if not current_platform.is_cuda_alike():
|
||||
return False
|
||||
|
||||
if quant_method != "awq":
|
||||
|
||||
47
vllm/utils/cpu_triton_utils.py
Normal file
47
vllm/utils/cpu_triton_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Contains replacement functions to fallback Triton usages in CPU backend
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class _FuncWrapper:
|
||||
def __init__(self, func: Callable) -> None:
|
||||
self.func = func
|
||||
|
||||
def __getitem__(self, *args, **kwargs) -> Callable:
|
||||
return self.func
|
||||
|
||||
|
||||
# For _compute_slot_mapping_kernel in vllm/v1/worker/block_table.py
|
||||
def _compute_slot_mapping_kernel_impl(
|
||||
num_tokens: int,
|
||||
max_num_tokens: int,
|
||||
query_start_loc: torch.Tensor, # [num_reqs + 1], int32
|
||||
positions: torch.Tensor, # [num_tokens], int64
|
||||
block_table: torch.Tensor, # [max_num_reqs, max_num_blocks_per_req], int32
|
||||
block_table_stride: int, # max_num_blocks_per_req
|
||||
block_size: int,
|
||||
slot_mapping: torch.Tensor, # [max_num_tokens], int64
|
||||
TOTAL_CP_WORLD_SIZE: int,
|
||||
TOTAL_CP_RANK: int,
|
||||
CP_KV_CACHE_INTERLEAVE_SIZE: int,
|
||||
PAD_ID: int,
|
||||
BLOCK_SIZE: int,
|
||||
) -> None:
|
||||
assert TOTAL_CP_WORLD_SIZE == 1, "Context Parallelism is not supported on CPU."
|
||||
torch.ops._C.compute_slot_mapping_kernel_impl(
|
||||
query_start_loc,
|
||||
positions,
|
||||
block_table,
|
||||
slot_mapping,
|
||||
block_size,
|
||||
)
|
||||
|
||||
|
||||
compute_slot_mapping_kernel = _FuncWrapper(_compute_slot_mapping_kernel_impl)
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.utils.cpu_triton_utils as cpu_tl
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@@ -28,6 +29,7 @@ class CPUModelRunner(GPUModelRunner):
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
self._postprocess_tensors()
|
||||
self._postprocess_triton()
|
||||
|
||||
def _postprocess_tensors(self) -> None:
|
||||
# Note: replace device tensors with cpu tensors
|
||||
@@ -52,6 +54,13 @@ class CPUModelRunner(GPUModelRunner):
|
||||
if isinstance(v, CpuGpuBuffer):
|
||||
v.gpu = v.cpu
|
||||
|
||||
def _postprocess_triton(self) -> None:
    """Swap the Triton slot-mapping kernel for its CPU fallback.

    After this runs, block_table's module-level ``_compute_slot_mapping_kernel``
    resolves to the native-op-backed wrapper from ``cpu_triton_utils``.
    """
    # NOTE(review): imported inside the method, presumably to avoid an import
    # cycle at module load time — confirm before hoisting to file level.
    import vllm.v1.worker.block_table as block_table_mod

    block_table_mod._compute_slot_mapping_kernel = cpu_tl.compute_slot_mapping_kernel
|
||||
|
||||
@instrument(span_name="Loading (CPU)")
|
||||
def load_model(self, load_dummy_weights: bool = False) -> None:
|
||||
if load_dummy_weights:
|
||||
|
||||
Reference in New Issue
Block a user