diff --git a/tests/unit/test_d2_headpacked.py b/tests/unit/test_d2_headpacked.py index cdb57559..cd5d8346 100644 --- a/tests/unit/test_d2_headpacked.py +++ b/tests/unit/test_d2_headpacked.py @@ -46,7 +46,7 @@ def test_d2_headpacked_n1(): fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True) o = torch.zeros(T, hd, dtype=torch.bfloat16, device='cuda') - stream = cuda.cuStream(0) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) @@ -78,7 +78,7 @@ def test_d2_headpacked_basic(): fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h) o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda') - stream = cuda.cuStream(0) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) @@ -116,7 +116,7 @@ def test_d2_headpacked_flash(): fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h) o_padded = torch.zeros(128, hd, dtype=torch.bfloat16, device='cuda') - stream = cuda.cuStream(0) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) q_c = ct.from_dlpack(q_padded).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_padded)) k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) @@ -153,7 +153,7 @@ def test_d2_headpacked_hd128(): fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True, num_query_heads=n_h) o = torch.zeros(n_h * T, hd, dtype=torch.bfloat16, device='cuda') - stream = cuda.cuStream(0) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))