fix: cuda.CUstream import
This commit is contained in:
@@ -4,6 +4,7 @@ Uses @cute.jit to construct MMA objects inside a compiled context,
|
||||
where OperandMajorMode values are valid MLIR operands.
|
||||
"""
|
||||
import torch, math, sys
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32
|
||||
@@ -97,7 +98,7 @@ def probe_hd(hd):
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
stream = cutlass.cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
probe = BudgetProbe(head_dim=hd)
|
||||
print(f' Compiling hd={hd}...', flush=True)
|
||||
|
||||
Reference in New Issue
Block a user