[Platform] Add current_platform.num_compute_units interface (#35042)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
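For context, a minimal sketch of the kind of per-platform interface the
commit title describes, assuming it delegates to the device properties
PyTorch already exposes; the class names and method bodies below are
illustrative assumptions, not the committed implementation:

    import torch

    class Platform:
        # Hypothetical base-class hook (assumption): each platform reports
        # its number of compute units (CUs on ROCm, SMs on NVIDIA).
        @classmethod
        def num_compute_units(cls) -> int:
            raise NotImplementedError

    class RocmPlatform(Platform):
        @classmethod
        def num_compute_units(cls) -> int:
            # multi_processor_count maps to the CU count on ROCm builds
            # of PyTorch (and to the SM count on CUDA builds).
            return torch.cuda.get_device_properties(
                torch.cuda.current_device()
            ).multi_processor_count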
@@ -9,7 +9,7 @@ import vllm._custom_ops as ops
 from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import on_gfx950
-from vllm.utils.platform_utils import get_cu_count
+from vllm.utils.platform_utils import num_compute_units

 DTYPES = [torch.bfloat16, torch.float16]
 BIAS_MODES = [0, 1, 2]
@@ -121,7 +121,7 @@ def pad_fp8(weight):
 @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
 def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
     torch.manual_seed(seed)
-    cu_count = get_cu_count()
+    cu_count = num_compute_units()

     # Next ^2 of n
     N_p2 = 1 << (n - 1).bit_length()
@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = get_cu_count()
+    cu_count = num_compute_units()

     A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
     B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = get_cu_count()
+    cu_count = num_compute_units()

     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
 @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
 def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
     torch.manual_seed(seed)
-    cu_count = get_cu_count()
+    cu_count = num_compute_units()

     xavier = math.sqrt(2 / k)  # normalize to avoid large output-bias deltas
     A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel(
     ref_out = torch._scaled_mm(
         A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
     )
-    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
+    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)

     if xnorm:
         torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
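As a quick sanity check, the renamed helper can be exercised the same
way the tests above use it (assumes a ROCm- or CUDA-visible device; the
printed value is the GPU's CU/SM count):

    from vllm.utils.platform_utils import num_compute_units

    cu_count = num_compute_units()
    print(f"compute units on current device: {cu_count}")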