From 3cc00b12df1103c2e4558a2e1ca80e1faec2e269 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 10:03:37 +0000 Subject: [PATCH] fix: prepack cache key includes data_ptr, shape, dtype, device, N, K MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Old cache used only tag ('l1'/'l2'), so layer 1 would reuse layer 0's packed scales if the function object persisted. Now keyed by (tag, data_ptr, shape, dtype, device, N, K) — safe across layers. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index bbd3a27d..08692619 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -95,10 +95,19 @@ def _prepack_weight_sf(weight_sf, N, K, tag): Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS interleaved layout, skipping the per-call remap+memset+alloc. """ - cache_attr = f"_prepacked_{tag}" - cached = getattr(_prepack_weight_sf, cache_attr, None) - if cached is not None: - return cached + cache_key = ( + tag, + weight_sf.data_ptr(), + tuple(weight_sf.shape), + str(weight_sf.dtype), + weight_sf.device.index, + N, + K, + ) + if not hasattr(_prepack_weight_sf, '_cache'): + _prepack_weight_sf._cache = {} + if cache_key in _prepack_weight_sf._cache: + return _prepack_weight_sf._cache[cache_key] from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb @@ -112,7 +121,7 @@ def _prepack_weight_sf(weight_sf, N, K, tag): packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K)) packed = torch.stack(packed, dim=0).contiguous() - setattr(_prepack_weight_sf, cache_attr, packed) + _prepack_weight_sf._cache[cache_key] = packed if MEGA_MOE_DEBUG: print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} "