[Platform] Add current_platform.num_compute_units interface (#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
Kunshang Ji
2026-02-25 14:22:49 +08:00
committed by GitHub
parent 92510edc32
commit 8ad54a991b
24 changed files with 72 additions and 52 deletions

View File

@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -38,7 +38,7 @@ if current_platform.is_rocm():
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens):
return min(total_tokens, get_cu_count())
return min(total_tokens, num_compute_units())
@triton.jit
def cp_mha_gather_cache_kernel(