[Kernel] Factor out epilogues from cutlass kernels (#5391)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: zifeitong <zifei.tong@parasail.io>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
Tyler Michael Smith
2024-06-13 14:22:19 -04:00
committed by GitHub
parent 0ce7b952f8
commit 85657b5607
12 changed files with 274 additions and 232 deletions

View File

@@ -47,7 +47,7 @@ def cutlass_fp8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
@@ -74,7 +74,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_b = (torch.randn(
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=out_dtype)
@@ -180,11 +180,11 @@ def test_cutlass_subset():
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b *
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
@@ -203,8 +203,8 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype = out_dtype
def forward(self, a):
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
@pytest.mark.parametrize("per_act_token", [True, False])