diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index bbd3a27d..08692619 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -95,10 +95,19 @@ def _prepack_weight_sf(weight_sf, N, K, tag): Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS interleaved layout, skipping the per-call remap+memset+alloc. """ - cache_attr = f"_prepacked_{tag}" - cached = getattr(_prepack_weight_sf, cache_attr, None) - if cached is not None: - return cached + cache_key = ( + tag, + weight_sf.data_ptr(), + tuple(weight_sf.shape), + str(weight_sf.dtype), + weight_sf.device.index, + N, + K, + ) + if not hasattr(_prepack_weight_sf, '_cache'): + _prepack_weight_sf._cache = {} + if cache_key in _prepack_weight_sf._cache: + return _prepack_weight_sf._cache[cache_key] from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb @@ -112,7 +121,7 @@ def _prepack_weight_sf(weight_sf, N, K, tag): packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K)) packed = torch.stack(packed, dim=0).contiguous() - setattr(_prepack_weight_sf, cache_attr, packed) + _prepack_weight_sf._cache[cache_key] = packed if MEGA_MOE_DEBUG: print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} "