diff --git a/tests/unit/test_d1_hd512_only.py b/tests/unit/test_d1_hd512_only.py index 62599cdb..6ef04c1f 100644 --- a/tests/unit/test_d1_hd512_only.py +++ b/tests/unit/test_d1_hd512_only.py @@ -45,7 +45,11 @@ def test(): import time t0 = time.time() - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + from cutlass.base_dsl.compiler import CompileOptions, PtxasOptions, OptLevel + # PtxasOptions -j64: use 64 threads for ptxas register allocation (B200 has 256 cores) + # OptLevel(0): skip MLIR optimizations for faster compilation (first verify correctness, then optimize) + compile_opts = CompileOptions((PtxasOptions("-j64"), OptLevel(0))) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, config=compile_opts) t1 = time.time() print(f'Compilation took {t1-t0:.1f}s', flush=True)