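Speed up the kernels/quantization/ tests (#18669)

Summary of the diff below: the `triton_scaled_mm` import is hoisted from the body of `test_scaled_mm` to module scope, the reference helper `scaled_mm_torch` is renamed to `torch_scaled_mm`, and the test now compares the Triton and reference results directly instead of first copying every tensor to the CPU.

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in: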
@@ -13,8 +13,13 @@ from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
|
||||
def scaled_mm_torch(a: torch.Tensor,
|
||||
|
||||
def torch_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
@@ -101,21 +106,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
|
||||
if use_bias:
|
||||
bias = torch.rand((N, ), device=device, dtype=out_dtype)
|
||||
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
|
||||
c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
a_cpu = a.cpu()
|
||||
b_cpu = b.cpu()
|
||||
scale_a_cpu = scale_a.cpu()
|
||||
scale_b_cpu = scale_b.cpu()
|
||||
bias_cpu = None if bias is None else bias.cpu()
|
||||
c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
c_actual = scaled_mm_torch(a_cpu, b_cpu, scale_a_cpu, scale_b_cpu,
|
||||
out_dtype, bias_cpu)
|
||||
|
||||
c_check_cpu = c_check.cpu()
|
||||
torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
|
||||
torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1)
|
||||
|
||||
Reference in New Issue
Block a user