fix: import CUstream from cuda.bindings.driver instead of cutlass.cuda

This commit is contained in:
2026-05-23 06:40:05 +00:00
parent 39311133d6
commit d5273b7f4f

View File

@@ -4,6 +4,7 @@ Uses @cute.jit to construct MMA objects inside a compiled context,
where OperandMajorMode values are valid MLIR operands.
"""
import torch, math, sys
+import cuda.bindings.driver as cuda
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
@@ -97,7 +98,7 @@ def probe_hd(hd):
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
-stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
probe = BudgetProbe(head_dim=hd)
print(f' Compiling hd={hd}...', flush=True)