Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
@@ -36,9 +37,7 @@ MNK_FACTORS = [
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
# -1 means full extent in that dimension
|
||||
TENSORWISE_GROUP_SHAPE = (-1, -1)
|
||||
@@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape):
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
group_shape = group_scale_helper(shape, group_shape)
|
||||
return tuple(
|
||||
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
def cutlass_fp8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
@@ -80,8 +80,8 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
# make scales M-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
@@ -89,7 +89,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@@ -98,18 +98,19 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
def cutlass_int8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
@@ -118,11 +119,11 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@@ -131,145 +132,192 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % 4 != 0 and current_platform.has_device_capability(100):
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
def test_cutlass_int8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
def test_cutlass_int8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias, torch.bfloat16,
|
||||
device)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
torch.bfloat16,
|
||||
device,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device)
|
||||
def test_cutlass_int8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
# For the following two tests:
|
||||
@@ -277,32 +325,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool):
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
cutlass_fp8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool):
|
||||
def test_cutlass_int8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
cutlass_int8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@@ -310,8 +368,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skip
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
out_dtype: torch.dtype):
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
|
||||
# Currently, the test is failing because folding azp into
|
||||
# 16-bit bias loses too much precision
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
@@ -328,7 +385,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
@@ -340,18 +397,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
|
||||
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
|
||||
assert azp_bias.shape == (1, n)
|
||||
assert azp_bias[0, :].shape == (n, )
|
||||
assert azp_bias[0, :].shape == (n,)
|
||||
|
||||
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
|
||||
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
|
||||
dtype=out_dtype, device='cuda')
|
||||
baseline_q = (
|
||||
scale_a.to(device="cpu")
|
||||
* scale_b.to(device="cpu")
|
||||
* ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
|
||||
).to(dtype=out_dtype, device="cuda")
|
||||
|
||||
out = ops.cutlass_scaled_mm(aq_i8,
|
||||
bq_i8,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=azp_bias[0, :])
|
||||
out = ops.cutlass_scaled_mm(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
|
||||
)
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
@@ -362,8 +418,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("azp_per_token", [True, False])
|
||||
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
use_bias: bool, azp_per_token: bool):
|
||||
def test_cutlass_int8_azp(
|
||||
m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
|
||||
):
|
||||
m_azp = m if azp_per_token else 1
|
||||
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
@@ -377,16 +434,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand(
|
||||
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq,
|
||||
scale_a * aq_f32 - azp_a,
|
||||
rtol=1e-4,
|
||||
atol=1e-3)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
@@ -396,8 +449,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
|
||||
|
||||
# int32 mm not supported on CUDA
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
|
||||
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
|
||||
|
||||
# Hadamard is just the sum of the cols
|
||||
@@ -406,14 +459,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
func_bias = bias if use_bias else None
|
||||
|
||||
if azp_per_token:
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_adj_i32, azp_i32,
|
||||
func_bias)
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
|
||||
)
|
||||
else:
|
||||
azp_with_adj_i32 = azp_i32 * azp_adj_i32
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_with_adj_i32, None,
|
||||
func_bias)
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
|
||||
)
|
||||
|
||||
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
@@ -423,13 +476,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
if azp_per_token:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
|
||||
func_bias))
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
|
||||
)
|
||||
else:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
|
||||
func_bias))
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
|
||||
)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@@ -445,23 +500,14 @@ def test_cutlass_subset():
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class CutlassLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, b, scale_a, scale_b, out_dtype):
|
||||
super().__init__()
|
||||
self.b = b
|
||||
@@ -470,8 +516,9 @@ class CutlassLayer(torch.nn.Module):
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
return ops.cutlass_scaled_mm(
|
||||
a, self.b, self.scale_a, self.scale_b, self.out_dtype
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@@ -485,10 +532,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
m_a_scales = m if per_act_token else 1
|
||||
n_b_scales = n if per_out_ch else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
|
||||
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
# Construct a trivial model with a single layer that calls a CUTLASS kernel
|
||||
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
|
||||
@@ -502,13 +547,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
out.zero_()
|
||||
g.replay()
|
||||
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
|
||||
baseline = torch.mm(
|
||||
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
|
||||
).to(torch.bfloat16)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@@ -517,11 +563,13 @@ def test_cutlass_support_opcheck():
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_group_gemm(
|
||||
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
@@ -533,13 +581,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1),
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
@@ -566,75 +610,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
|
||||
None)
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
a_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_tensors_stacked = torch.empty(
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
|
||||
1]] = a_tensors[g]
|
||||
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[
|
||||
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
|
||||
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_scales_tensors[g]
|
||||
)
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
b_scales_tensors_stacked = torch.empty(
|
||||
(num_experts, n_b_scales), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
out_tensors_stacked = torch.zeros(
|
||||
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
ab_strides = torch.full((num_experts, ),
|
||||
a_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides = torch.full((num_experts, ),
|
||||
out_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides = torch.full(
|
||||
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides = torch.full(
|
||||
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
|
||||
b_tensors_stacked, a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked, expert_offsets[:-1],
|
||||
problem_sizes, ab_strides, ab_strides, c_strides,
|
||||
per_act_token, per_out_ch)
|
||||
ops.cutlass_moe_mm(
|
||||
out_tensors_stacked,
|
||||
a_tensors_stacked,
|
||||
b_tensors_stacked,
|
||||
a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
ab_strides,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
|
||||
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
|
||||
Reference in New Issue
Block a user