diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 26ca280d..0845fc0a 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -354,8 +354,9 @@ def run_nvfp4_grouped_gemm( K_packed = mat_a.shape[1] N_packed = mat_b.shape[2] cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed) + use_cache = True # TEMP: set False to always recompile (debug) - if False and cache_key not in _compiled_kernel_cache: # TEMP: always recompile + if use_cache and cache_key in _compiled_kernel_cache: kernel = ScaledGroupedGemmKernel( scenario="2Dx3D", sf_vec_size=SF_VEC_SIZE,