fix: use CUstream instead of cuStream(0)
This commit is contained in:
@@ -46,7 +46,7 @@ def test_d2_headpacked_n1():
|
||||
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
|
||||
o = torch.zeros(T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
stream = cuda.cuStream(0)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
@@ -78,7 +78,7 @@ def test_d2_headpacked_basic():
|
||||
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
|
||||
o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
stream = cuda.cuStream(0)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
@@ -116,7 +116,7 @@ def test_d2_headpacked_flash():
|
||||
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
|
||||
o_padded = torch.zeros(128, hd, dtype=torch.bfloat16, device='cuda')
|
||||
stream = cuda.cuStream(0)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
q_c = ct.from_dlpack(q_padded).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_padded))
|
||||
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
@@ -153,7 +153,7 @@ def test_d2_headpacked_hd128():
|
||||
|
||||
fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h)
|
||||
o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
stream = cuda.cuStream(0)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
|
||||
Reference in New Issue
Block a user