[Kernel] Factor out epilogues from cutlass kernels (#5391)
Co-authored-by: Michael Goin <michael@neuralmagic.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: zifeitong <zifei.tong@parasail.io> Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0ce7b952f8
commit
85657b5607
@@ -47,7 +47,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(out_dtype)
|
||||
|
||||
@@ -74,7 +74,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device=device, dtype=torch.float32) / 10)
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype)
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=out_dtype)
|
||||
@@ -180,11 +180,11 @@ def test_cutlass_subset():
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm_dq(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
out = ops.cutlass_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b *
|
||||
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
|
||||
@@ -203,8 +203,8 @@ class CutlassLayer(torch.nn.Module):
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user