[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling (#11868)

This commit is contained in:
Lucas Wilkinson
2025-01-30 21:33:00 -05:00
committed by GitHub
parent 4078052f09
commit 9798b2fb00
25 changed files with 1924 additions and 346 deletions

View File

@@ -10,6 +10,7 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import cdiv
from .utils import baseline_scaled_mm, to_fp8, to_int8
@@ -39,6 +40,11 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
PER_TOKEN_GROUP_SHAPE = (1, -1)
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
@@ -47,11 +53,22 @@ def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def group_scale_helper(shape, group_shape):
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
group_shape = group_scale_helper(shape, group_shape)
return tuple(
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def cutlass_fp8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
@@ -60,13 +77,17 @@ def cutlass_fp8_gemm_helper(m: int,
a = to_fp8(torch.randn((m, k), device=device))
b = to_fp8(torch.randn((n, k), device=device).t())
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a = scale_a.t().contiguous().t()
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
else:
@@ -84,8 +105,8 @@ def cutlass_fp8_gemm_helper(m: int,
def cutlass_int8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
@@ -94,13 +115,11 @@ def cutlass_int8_gemm_helper(m: int,
a = to_int8(torch.randn((m, k), device=device) * 5)
b = to_int8(torch.randn((n, k), device=device).t() * 5)
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
@@ -117,82 +136,135 @@ def cutlass_int8_gemm_helper(m: int,
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
torch.bfloat16, device)
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
b_scale_group_shape, use_bias, torch.bfloat16,
device)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device)
@@ -203,28 +275,32 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias)
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias)
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
@pytest.mark.parametrize("m", [32, 64, 128])