[Kernel][Bugfix] Refactor and Fix CUTLASS 2:4 Sparse Kernels (#13198)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith
2025-02-13 19:01:14 -05:00
committed by GitHub
parent 2344192a55
commit c1e37bf71b
16 changed files with 576 additions and 473 deletions

View File

@@ -7,7 +7,6 @@ from typing import Tuple, Type
import pytest
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -55,11 +54,39 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape)
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
b_compressed,
b_metadata,
eye_scale,
eye_scale,
out_dtype=out_dtype)
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()
if dtype == torch.int8:
# ensure A and B aren't all zeros after rounding
a = a * 5.0
b = b * 5.0
b = prune_to_2_4(b.t()).t()
@@ -75,6 +102,7 @@ def make_rand_sparse_tensors(
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
check_compress_decompress_invariance(dtype, b, b_compressed, e)
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
@@ -134,27 +162,37 @@ MNK_FACTORS = [
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=dtype)
baseline = F.linear(a, b.T)
out_dtype=dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
@@ -162,27 +200,34 @@ def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out_dtype=out_dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
@@ -198,18 +243,24 @@ def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out_dtype=out_dtype,
bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)