[Platform] Add current_platform.num_compute_units interface (#35042)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -38,7 +38,7 @@ if current_platform.is_rocm():
|
||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||
|
||||
def num_programs(total_tokens):
|
||||
return min(total_tokens, get_cu_count())
|
||||
return min(total_tokens, num_compute_units())
|
||||
|
||||
@triton.jit
|
||||
def cp_mha_gather_cache_kernel(
|
||||
|
||||
Reference in New Issue
Block a user