fix: prepack cache key includes data_ptr, shape, dtype, device, N, K

Old cache used only tag ('l1'/'l2'), so layer 1 would reuse layer 0's
packed scales if the function object persisted. Now keyed by
(tag, data_ptr, shape, dtype, device, N, K) — safe across layers.
This commit is contained in:
2026-05-15 10:03:37 +00:00
parent 3ba41b9322
commit 3cc00b12df

View File

@@ -95,10 +95,19 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS
interleaved layout, skipping the per-call remap+memset+alloc.
"""
cache_attr = f"_prepacked_{tag}"
cached = getattr(_prepack_weight_sf, cache_attr, None)
if cached is not None:
return cached
cache_key = (
tag,
weight_sf.data_ptr(),
tuple(weight_sf.shape),
str(weight_sf.dtype),
weight_sf.device.index,
N,
K,
)
if not hasattr(_prepack_weight_sf, '_cache'):
_prepack_weight_sf._cache = {}
if cache_key in _prepack_weight_sf._cache:
return _prepack_weight_sf._cache[cache_key]
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb
@@ -112,7 +121,7 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K))
packed = torch.stack(packed, dim=0).contiguous()
setattr(_prepack_weight_sf, cache_attr, packed)
_prepack_weight_sf._cache[cache_key] = packed
if MEGA_MOE_DEBUG:
print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} "