fix: gather on slot_token presence, add shape asserts L1→L2
- Remove torch.equal heuristic — just gather when slot_token is provided - Add asserts for slot mapping shapes (ndim, numel == num_slots) - Add post-L1 and pre-L2 shape asserts (l1_slots, activated, l1_fp4, l1_sf_out)
This commit is contained in:
@@ -94,10 +94,9 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
|
||||
|
||||
# Gather input rows by slot_token — needed when x_fp4 has token rows
|
||||
# but we need slot rows (L1). When slot_token is the identity (L2),
|
||||
# x_fp4 already has slot rows and the gather is a no-op.
|
||||
if slot_token is not None and not torch.equal(slot_token, torch.arange(num_slots, device=x_fp4.device)):
|
||||
# Gather input rows by slot_token when provided (L1: tokens→slots).
|
||||
# L2 doesn't pass slot_token, so no gather needed.
|
||||
if slot_token is not None:
|
||||
slot_x = x_fp4[slot_token]
|
||||
slot_x_sf = x_sf[slot_token]
|
||||
else:
|
||||
|
||||
@@ -369,6 +369,14 @@ def nvfp4_mega_moe_full(
|
||||
# Ensure alpha is a plain Python float (C extension can't handle torch scalars)
|
||||
l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
|
||||
# Shape consistency asserts — catch mismatched slot mappings early
|
||||
assert slot_expert_local.ndim == 1
|
||||
assert slot_token.ndim == 1
|
||||
assert slot_weight.ndim == 1
|
||||
assert slot_expert_local.numel() == num_slots
|
||||
assert slot_token.numel() == num_slots
|
||||
assert slot_weight.numel() == num_slots
|
||||
|
||||
# Prepack SFB weight scales into CUTLASS layout (lazy, once per layer)
|
||||
l1_N = l1_w.shape[2]
|
||||
l1_K = l1_w.shape[1] * 2
|
||||
@@ -386,6 +394,9 @@ def nvfp4_mega_moe_full(
|
||||
sfb_prepacked=True,
|
||||
) # (num_slots, 2*INTER) bfloat16
|
||||
|
||||
# Post-L1 shape asserts
|
||||
assert l1_slots.shape[0] == num_slots
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L1-out] nan={torch.isnan(l1_slots).any().item()} "
|
||||
f"abs_max={l1_slots.abs().max().item():.4e}")
|
||||
@@ -402,6 +413,11 @@ def nvfp4_mega_moe_full(
|
||||
|
||||
# Step 4: Quantize activated slots → FP4
|
||||
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)
|
||||
|
||||
# Pre-L2 shape asserts
|
||||
assert activated.shape[0] == num_slots
|
||||
assert l1_fp4.shape[0] == num_slots
|
||||
assert l1_sf_out.shape[0] == num_slots
|
||||
l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
|
||||
Reference in New Issue
Block a user