[Platform] Add current_platform.num_compute_units interface (#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
Kunshang Ji
2026-02-25 14:22:49 +08:00
committed by GitHub
parent 92510edc32
commit 8ad54a991b
24 changed files with 72 additions and 52 deletions

View File

@@ -9,6 +9,7 @@ import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import num_compute_units
def cal_diff(
@@ -124,8 +125,7 @@ def test_cutlass_mla_decode(
q_pe = q_pe_padded
kv_cache_flat = blocked_k.squeeze(2)
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
sm_count = device_properties.multi_processor_count
sm_count = num_compute_units(device.index)
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen * block_size, b, sm_count, num_kv_splits=1
)

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.platform_utils import num_compute_units
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
@@ -78,7 +79,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
if has_zp:
zp = zp.to(dtype)
properties = torch.cuda.get_device_properties(qw.device.index)
sm_count = properties.multi_processor_count
sm_count = num_compute_units(qw.device.index)
sm_version = properties.major * 10 + properties.minor
n_32align = (n + 32 - 1) // 32 * 32

View File

@@ -9,7 +9,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
DTYPES = [torch.bfloat16, torch.float16]
BIAS_MODES = [0, 1, 2]
@@ -121,7 +121,7 @@ def pad_fp8(weight):
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
# Next ^2 of n
N_p2 = 1 << (n - 1).bit_length()
@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel(
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)
if xnorm:
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -45,7 +46,7 @@ class AllSparkLinearKernel(MPLinearKernel):
# prepare the parameters required for the kernel
properties = torch.cuda.get_device_properties(device.index)
sm_count = properties.multi_processor_count
sm_count = num_compute_units(device.index)
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args["sm_count"] = sm_count

View File

@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
@@ -36,7 +36,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
out_dtype,
As,
Bs,
get_cu_count(),
num_compute_units(),
bias,
)
# Fallback

View File

@@ -9,6 +9,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -147,7 +148,7 @@ def matmul_persistent(
assert bias is None or bias.dim() == 1, (
"Currently assuming bias is 1D, let Horace know if you run into this"
)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
NUM_SMS = num_compute_units(a.device.index)
M, K = a.shape
K, N = b.shape
dtype = a.dtype

View File

@@ -13,8 +13,6 @@
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from functools import lru_cache
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -22,6 +20,7 @@ from einops import rearrange
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.utils.platform_utils import num_compute_units
from .utils import input_guard
@@ -162,15 +161,8 @@ def layer_norm_fwd_kernel(
tl.store(Y_base, y, mask=mask)
@lru_cache
def _get_sm_count(device: torch.device) -> int:
"""Get and cache the SM count for a given device."""
props = torch.cuda.get_device_properties(device)
return props.multi_processor_count
def calc_rows_per_block(M: int, device: torch.device) -> int:
sm_count = _get_sm_count(device)
sm_count = num_compute_units(device.index)
rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
rows_per_block = min(rows_per_block, 4)
return rows_per_block

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.platform_utils import num_compute_units
from .quant_utils import pack_cols, unpack_cols
@@ -271,7 +272,7 @@ def marlin_make_workspace_new(
) -> torch.Tensor:
# In the new marlin kernel, we use the num of threadblocks as workspace
# size. The num of threadblocks is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
sms = num_compute_units(device.index)
return torch.zeros(
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
)

View File

@@ -11,7 +11,7 @@ from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -149,7 +149,7 @@ def rocm_unquantized_gemm_impl(
m = weight.shape[0]
k = weight.shape[1]
cu_count = get_cu_count()
cu_count = num_compute_units()
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
@@ -199,7 +199,7 @@ def rocm_unquantized_gemm_impl(
x_view = x.reshape(-1, x.size(-1))
if m > 8 and 0 < n <= 4:
cu_count = get_cu_count()
cu_count = num_compute_units()
out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:

View File

@@ -26,6 +26,7 @@ from vllm.utils.deep_gemm import (
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
def _generate_optimal_warmup_m_values(
@@ -44,7 +45,7 @@ def _generate_optimal_warmup_m_values(
# DeepGEMM's possible block sizes
block_ms = [64, 128, 256]
block_ns = list(range(16, min(257, n + 1), 16))
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
num_sms = num_compute_units(device.index)
m_values = set()

View File

@@ -538,6 +538,10 @@ class CudaPlatformBase(Platform):
def support_static_graph_mode(cls) -> bool:
return True
@classmethod
def num_compute_units(cls, device_id=0):
return torch.cuda.get_device_properties(device_id).multi_processor_count
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -692,6 +692,16 @@ class Platform:
"""
return {}
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
"""
Get the number of compute units for the current platform.
(NVIDIA SM / AMD CU / Intel EU)
"""
raise NotImplementedError(
"num_compute_units is not implemented for the current platform."
)
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@@ -682,3 +682,7 @@ class RocmPlatform(Platform):
@classmethod
def support_static_graph_mode(cls) -> bool:
return True
@classmethod
def num_compute_units(cls, device_id=0):
return torch.cuda.get_device_properties(device_id).multi_processor_count

View File

@@ -277,3 +277,7 @@ class XPUPlatform(Platform):
"""Copy blocks from XPU to host (CPU)."""
_src_cache = src_cache[:, src_block_indices]
dst_cache[:, dst_block_indices] = _src_cache.cpu()
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
return torch.xpu.get_device_properties(device_id).max_compute_units

View File

@@ -24,11 +24,6 @@ def xpu_is_initialized() -> bool:
return torch.xpu.is_initialized()
def get_cu_count(device_id: int = 0) -> int:
"""Returns the total number of compute units (CU) on single GPU."""
return torch.cuda.get_device_properties(device_id).multi_processor_count
def cuda_get_device_properties(
device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
@@ -57,3 +52,11 @@ def is_uva_available() -> bool:
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return is_pin_memory_available()
@cache
def num_compute_units(device_id: int = 0) -> int:
"""Get the number of compute units of the current device."""
from vllm.platforms import current_platform
return current_platform.num_compute_units(device_id)

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadataBuilder,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
@@ -74,8 +75,7 @@ class SM100Workspace:
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
# (assumes all devices are similar)
properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
self._sm_count = properties.multi_processor_count
self._sm_count = num_compute_units(0)
def get_buf(self):
return self._workspace_buf

View File

@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
@@ -130,8 +131,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
num_sms = num_compute_units(self.device.index)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -237,8 +238,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device)
sm_count = props.multi_processor_count
sm_count = num_compute_units(device.index)
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)

View File

@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -219,8 +220,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
self.reorder_batch_threshold += self.num_speculative_tokens
props = torch.cuda.get_device_properties(self.device)
sm_count = props.multi_processor_count
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
self.decode_lens_buffer = torch.empty(

View File

@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -38,7 +38,7 @@ if current_platform.is_rocm():
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens):
return min(total_tokens, get_cu_count())
return min(total_tokens, num_compute_units())
@triton.jit
def cp_mha_gather_cache_kernel(

View File

@@ -13,6 +13,7 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import num_compute_units
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
@@ -988,7 +989,7 @@ def apply_top_k_top_p_triton(
else:
p_ptr = logits # Dummy pointer (won't be read)
num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count
num_sm = num_compute_units(logits.device.index)
NUM_PROGRAMS = min(num_sm, batch_size)
# Cache per-Triton Program buffer on each device.

View File

@@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import (
get_dtype_size,
kv_cache_dtype_str_to_dtype,
@@ -909,8 +909,8 @@ class GPUModelRunner(
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
self.num_sms = num_compute_units(self.device.index)
# Note: used for model runner override.
def _sync_device(self) -> None:

View File

@@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@@ -72,8 +73,7 @@ class SMControlContextManager:
"SM control is currently only supported on CUDA"
)
props = torch.cuda.get_device_properties(torch.cuda.current_device())
total_sms = props.multi_processor_count
total_sms = num_compute_units(torch.cuda.current_device().index)
assert comm_sms < total_sms
self.total_sms = total_sms

View File

@@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner):
# FIXME: To be verified.
self.cascade_attn_enabled = False
def _init_device_properties(self) -> None:
self.num_sms = None
def _sync_device(self) -> None:
torch.xpu.synchronize()