[Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)

Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
This commit is contained in:
Tyler Michael Smith
2024-05-22 10:10:43 -04:00
committed by GitHub
parent c74c913bfb
commit 8674f9880e
3 changed files with 50 additions and 2 deletions

View File

@@ -190,3 +190,44 @@ def test_cutlass_subset():
b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
def __init__(self, b, scale_a, scale_b, out_dtype):
super().__init__()
self.b = b
self.scale_a = scale_a
self.scale_b = scale_b
self.out_dtype = out_dtype
def forward(self, a):
return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
def test_cutlass_cuda_graph():
m, n, k = 512, 512, 512
a = to_int8(torch.randn((m, k), device="cuda"))
b = to_int8(torch.randn((n, k), device="cuda").t())
scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10)
scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10)
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
# Run the model with a cuda graph
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
out = model(a)
out.zero_()
g.replay()
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)