fix: use CUstream instead of cuStream(0)

This commit is contained in:
2026-05-25 16:51:52 +00:00
parent 4826fa6afb
commit af136eee27

View File

@@ -46,7 +46,7 @@ def test_d2_headpacked_n1():
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
o = torch.zeros(T, hd, dtype=torch.bfloat16, device='cuda')
stream = cuda.cuStream(0)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
@@ -78,7 +78,7 @@ def test_d2_headpacked_basic():
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda')
stream = cuda.cuStream(0)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
@@ -116,7 +116,7 @@ def test_d2_headpacked_flash():
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
o_padded = torch.zeros(128, hd, dtype=torch.bfloat16, device='cuda')
stream = cuda.cuStream(0)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q_c = ct.from_dlpack(q_padded).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_padded))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
@@ -153,7 +153,7 @@ def test_d2_headpacked_hd128():
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda')
stream = cuda.cuStream(0)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))