diff --git a/tests/unit/test_d1_qk512.py b/tests/unit/test_d1_qk512.py index 2fa4f249..344167df 100644 --- a/tests/unit/test_d1_qk512.py +++ b/tests/unit/test_d1_qk512.py @@ -1,6 +1,6 @@ """Minimal hd=512 test: ONLY QK GEMM, no softmax, no PV. Goal: isolate whether the compilation hang is from QK or softmax/PV.""" -import torch, math, time +import torch, math, time, cutlass import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda