[Platform] Add current_platform.num_compute_units interface (#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
Kunshang Ji
2026-02-25 14:22:49 +08:00
committed by GitHub
parent 92510edc32
commit 8ad54a991b
24 changed files with 72 additions and 52 deletions

View File

@@ -9,7 +9,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.platform_utils import num_compute_units
DTYPES = [torch.bfloat16, torch.float16]
BIAS_MODES = [0, 1, 2]
@@ -121,7 +121,7 @@ def pad_fp8(weight):
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
# Next ^2 of n
N_p2 = 1 << (n - 1).bit_length()
@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = get_cu_count()
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel(
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)
if xnorm:
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)