fix: prepack cache key includes data_ptr, shape, dtype, device, N, K
Old cache used only tag ('l1'/'l2'), so layer 1 would reuse layer 0's
packed scales if the function object persisted. Now keyed by
(tag, data_ptr, shape, dtype, device, N, K) — safe across layers.
This commit is contained in:
@@ -95,10 +95,19 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
|
||||
Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS
|
||||
interleaved layout, skipping the per-call remap+memset+alloc.
|
||||
"""
|
||||
cache_attr = f"_prepacked_{tag}"
|
||||
cached = getattr(_prepack_weight_sf, cache_attr, None)
|
||||
if cached is not None:
|
||||
return cached
|
||||
cache_key = (
|
||||
tag,
|
||||
weight_sf.data_ptr(),
|
||||
tuple(weight_sf.shape),
|
||||
str(weight_sf.dtype),
|
||||
weight_sf.device.index,
|
||||
N,
|
||||
K,
|
||||
)
|
||||
if not hasattr(_prepack_weight_sf, '_cache'):
|
||||
_prepack_weight_sf._cache = {}
|
||||
if cache_key in _prepack_weight_sf._cache:
|
||||
return _prepack_weight_sf._cache[cache_key]
|
||||
|
||||
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb
|
||||
|
||||
@@ -112,7 +121,7 @@ def _prepack_weight_sf(weight_sf, N, K, tag):
|
||||
packed.append(prepack_sfb(weight_sf[e], M_for_layout, N, K))
|
||||
|
||||
packed = torch.stack(packed, dim=0).contiguous()
|
||||
setattr(_prepack_weight_sf, cache_attr, packed)
|
||||
_prepack_weight_sf._cache[cache_key] = packed
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} "
|
||||
|
||||
Reference in New Issue
Block a user