[Kernel] Add per-tensor and per-token AZP epilogues (#5941)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -28,13 +28,16 @@ def to_int8(tensor: torch.Tensor):
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
@@ -221,6 +224,121 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skip
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
out_dtype: torch.dtype):
|
||||
# Currently, the test is failing because folding azp into
|
||||
# 16-bit bias loses too much precision
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
|
||||
assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
|
||||
|
||||
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
|
||||
|
||||
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
|
||||
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
|
||||
assert azp_bias.shape == (1, n)
|
||||
assert azp_bias[0, :].shape == (n, )
|
||||
|
||||
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
|
||||
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
|
||||
dtype=out_dtype, device='cuda')
|
||||
|
||||
out = ops.cutlass_scaled_mm(aq_i8,
|
||||
bq_i8,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=azp_bias[0, :])
|
||||
assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("azp_per_token", [True, False])
|
||||
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
use_bias: bool, azp_per_token: bool):
|
||||
m_azp = m if azp_per_token else 1
|
||||
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand(
|
||||
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
else:
|
||||
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
|
||||
|
||||
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
|
||||
|
||||
# int32 mm not supported on CUDA
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
|
||||
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
|
||||
|
||||
# Hadamard is just the sum of the cols
|
||||
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
|
||||
func_bias = bias if use_bias else None
|
||||
|
||||
if azp_per_token:
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_adj_i32, azp_i32,
|
||||
func_bias)
|
||||
else:
|
||||
azp_with_adj_i32 = azp_i32 * azp_adj_i32
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_with_adj_i32, None,
|
||||
func_bias)
|
||||
|
||||
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
|
||||
atol = 1e-3
|
||||
assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
|
||||
Reference in New Issue
Block a user