[Platform] Add current_platform.num_compute_units interface (#35042)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
|
||||
def cal_diff(
|
||||
@@ -124,8 +125,7 @@ def test_cutlass_mla_decode(
|
||||
q_pe = q_pe_padded
|
||||
|
||||
kv_cache_flat = blocked_k.squeeze(2)
|
||||
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
sm_count = device_properties.multi_processor_count
|
||||
sm_count = num_compute_units(device.index)
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user