fix: cuda.CUstream import

This commit is contained in:
2026-05-23 06:40:05 +00:00
parent 1c20b826d9
commit de439bcd75

View File

@@ -4,6 +4,7 @@ Uses @cute.jit to construct MMA objects inside a compiled context,
where OperandMajorMode values are valid MLIR operands.
"""
import torch, math, sys
import cuda.bindings.driver as cuda
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
@@ -97,7 +98,7 @@ def probe_hd(hd):
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
probe = BudgetProbe(head_dim=hd)
print(f' Compiling hd={hd}...', flush=True)