[Platform] Add current_platform.num_compute_units interface (#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
Kunshang Ji
2026-02-25 14:22:49 +08:00
committed by GitHub
parent 92510edc32
commit 8ad54a991b
24 changed files with 72 additions and 52 deletions

View File

@@ -9,6 +9,7 @@ import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import num_compute_units
def cal_diff(
@@ -124,8 +125,7 @@ def test_cutlass_mla_decode(
q_pe = q_pe_padded
kv_cache_flat = blocked_k.squeeze(2)
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
sm_count = device_properties.multi_processor_count
sm_count = num_compute_units(device.index)
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen * block_size, b, sm_count, num_kv_splits=1
)