diff --git a/README.md b/README.md index a2f8f166..d5843cf4 100644 --- a/README.md +++ b/README.md @@ -70,14 +70,21 @@ The kernel uses a **slot representation** instead of collapsing expert outputs e The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final `index_add_` scatter. -### Prepacked SFB (Weight Scale Factors) +### SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached -Weight scale factors (SFB) are pre-remapped into CUTLASS interleaved layout once at first forward pass (lazily cached). This eliminates per-GEMM: -- 1 `cudaMalloc`-ish allocation for SFB -- 1 `cudaMemsetAsync` for SFB padding -- 1 remap kernel launch for SFB +Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. This is a lightweight scatter kernel (~µs) and is NOT the bottleneck compared to the GEMM itself. -The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked SFB pointer directly. Only SFA (activation scales) is remapped dynamically — those change every forward pass. +⚠️ **DO NOT ADD A PREPACK CACHE FOR SFB.** Previous attempts caused critical issues: + +| Problem | Impact | +|---------|--------| +| **OOM** | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) = ~214 GiB — exceeds B200 capacity | +| **Peak memory 2×** | `torch.stack` held all expert tensors + final stack simultaneously before LRU eviction | +| **CUDA graph trap** | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash | +| **M-dependent layout** | `prepack_sfb(M=128)` assumed SFB layout size is M-independent (never verified). If wrong, entire prepack is invalid | +| **Cross-layer cache collision** | Tag-based cache (`"l1"`/`"l2"`) returned layer N-1's data for layer N. Fixed with data_ptr key, but the cache itself was the root problem | + +The per-call remap costs microseconds. The cache cost was hours of debugging. 
Don't repeat this mistake. --- @@ -117,7 +124,7 @@ The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked S 6. Profile run (warmup) └─ First forward pass to allocate KV cache, etc. └─ This is where the CUTLASS GEMM first executes - └─ SFB weight scales are prepacked into CUTLASS layout (lazy, cached) + └─ SFB weight scales remapped per-expert inside CUTLASS (no cache) 7. Ready to serve ``` @@ -129,7 +136,7 @@ The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked S ``` nvfp4_megamoe_kernel/ ├── __init__.py # Public API exports -├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation, prepack +├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation ├── weight_transform.py # Weight prep: fold global scale, pack UE4M3 ├── symm_buffer.py # GPU buffer allocation for MoE dispatch │ @@ -150,9 +157,9 @@ nvfp4_megamoe_kernel/ |------|-------------|--------------| | `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` | | `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. | -| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4), `unpack_ue4m3_u32`, and `_prepack_weight_sf` (lazy SFB prepack). | -| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (with prepacked SFB), write results to slot buffer. No routing weights — caller handles scatter. 
| -| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA remap. SFB remap done once at prepack time. Two GEMM entry points: standard (both remap) and prepacked-sfb (SFA remap only). | +| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4) and `unpack_ue4m3_u32`. NO prepack cache — SFB remapped per-call inside CUTLASS. | +| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (SFB remapped inside C extension), write results to slot buffer. No routing weights — caller handles scatter. | +| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA and SFB remap. | | `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). 
| --- diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index d2be3bab..4365f3bd 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -170,8 +170,8 @@ int cutlass_nvfp4_gemm_run( using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - int sfa_size = cute::size(layout_SFA); - int sfb_size = cute::size(layout_SFB); + int sfa_size = cute::cosize(layout_SFA); + int sfb_size = cute::cosize(layout_SFB); int K_sf = K / InputSFVectorSize; cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size); @@ -223,7 +223,7 @@ extern "C" int cutlass_nvfp4_sfb_size( ) { using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - *out_size = cute::size(layout_SFB); + *out_size = cute::cosize(layout_SFB); return 0; } @@ -237,7 +237,7 @@ extern "C" int cutlass_nvfp4_prepack_sfb_run( LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - int sfb_size = cute::size(layout_SFB); + int sfb_size = cute::cosize(layout_SFB); int K_sf = K / InputSFVectorSize; cudaMemsetAsync(static_cast<void*>(SFB_cutlass_ptr), 0, sfb_size * sizeof(ElementSF), stream); @@ -278,7 +278,7 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb( using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; - int sfa_size = cute::size(layout_SFA); + int sfa_size = cute::cosize(layout_SFA); int K_sf = K / InputSFVectorSize; // Only remap SFA (activation scales) — SFB is
prepacked diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index b7aa774c..124a8205 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -57,16 +57,16 @@ def cutlass_grouped_nvfp4_gemm( x_fp4, # (num_slots_or_tokens, K_half) int8 packed E2M1 x_sf, # (num_slots_or_tokens, sf_k) float8_e4m3fn block scales weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS - weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major — or prepacked (E_per_rank, sfb_size) if sfb_prepacked=True + weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange) alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale) - sfb_prepacked=False, # True if weight_sf is already prepacked into CUTLASS layout ): """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4. Takes 1D per-slot expert IDs and token indices (pre-built by caller). - Returns slot-based output: one row per (token, topk) slot. + SFB weight scales are remapped per-expert inside CUTLASS on each call. + NO prepack cache — see nvfp4_mega_moe.py for rationale. For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows. For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots). 
@@ -100,7 +100,7 @@ def cutlass_grouped_nvfp4_gemm( if MEGA_MOE_DEBUG: print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " - f"experts={num_experts} sfb_prepacked={sfb_prepacked}") + f"experts={num_experts}") slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device) @@ -118,14 +118,21 @@ def cutlass_grouped_nvfp4_gemm( if MEGA_MOE_DEBUG and e < 3 and M_expert > 0: print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} " - f"w shape={expert_w.shape} sfb_prepacked={sfb_prepacked}") + f"w shape={expert_w.shape}") + + # Shape/dtype contract asserts — SFB bugs hide in silent shape mismatches + assert expert_x.shape == (M_expert, K // 2), f"expert_x shape {expert_x.shape} != ({M_expert}, {K // 2})" + assert expert_x_sf.shape == (M_expert, K // 16), f"SFA shape {expert_x_sf.shape} != ({M_expert}, {K // 16})" + assert expert_w.shape == (K // 2, N), f"expert_w shape {expert_w.shape} != ({K // 2}, {N})" + assert expert_w_sf.shape == (K // 16, N), f"SFB shape {expert_w_sf.shape} != ({K // 16}, {N})" + assert expert_x_sf.dtype == torch.float8_e4m3fn, f"SFA dtype {expert_x_sf.dtype}" + assert expert_w_sf.dtype == torch.float8_e4m3fn, f"SFB dtype {expert_w_sf.dtype}" expert_out = cutlass_nvfp4_blockscaled_gemm( expert_x, expert_x_sf, expert_w, expert_w_sf, M_expert, N, K, alpha=alpha, - sfb_prepacked=sfb_prepacked, ) if MEGA_MOE_DEBUG: diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 12aae590..72576c24 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -35,10 +35,11 @@ def unpack_ue4m3_u32(x_u32): whose bits are 0x3F (~0.984), NOT the integer 63. CUDA doesn't implement bitwise ops on uint32, so we cast to int32 first. + Supports ND tensors — last dim is the packed dim (N words → N*4 float8 values). 
""" # CUDA uint32 lacks bitwise ops — use int32 x_i32 = x_u32.to(torch.int32) - M, N = x_i32.shape + *prefix, n_words = x_i32.shape # Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) @@ -46,12 +47,12 @@ def unpack_ue4m3_u32(x_u32): b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) - # Interleave into (M, N*4) - out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) - out[:, 0::4] = b0 - out[:, 1::4] = b1 - out[:, 2::4] = b2 - out[:, 3::4] = b3 + # Interleave into (*prefix, n_words*4) + out = torch.empty(*prefix, n_words * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) + out[..., 0::4] = b0 + out[..., 1::4] = b1 + out[..., 2::4] = b2 + out[..., 3::4] = b3 return out # CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell) @@ -89,100 +90,14 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) # Main kernel entry points # --------------------------------------------------------------------------- -def _prepack_weight_sf(weight_sf, N, K, tag): - """Lazily prepack SFB weight scales into CUTLASS layout (once per tag). - - Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS - interleaved layout, skipping the per-call remap+memset+alloc. - """ - cache_key = ( - tag, - weight_sf.data_ptr(), - tuple(weight_sf.shape), - str(weight_sf.dtype), - weight_sf.device.index, - N, - K, - ) - # ───────────────────────────────────────────────────────────────────── - # PREPACK CACHE — LRU with configurable max size - # - # Each prepacked SFB tensor is ~1.75 GiB per rank. With 61 MoE layers - # × 2 (L1+L2), an unbounded cache would consume ~214 GiB — well beyond - # B200 capacity. Since vLLM calls layers sequentially, only 2 entries - # are needed at a time (current layer's L1 + L2). 
- # - # WARNING: If you enable CUDA graphs in the future, multiple layers may - # be captured in a single graph replay, and the prepacked tensors must - # remain alive for the graph's lifetime. In that case, increase - # MEGA_MOE_PREPACK_CACHE_MAX to cover all captured layers, or switch - # to a persistent pre-allocation scheme. The default of 2 will cause - # use-after-free on evicted entries if CUDA graphs span >1 layer. - # ───────────────────────────────────────────────────────────────────── - _max_cache = int(os.environ.get('MEGA_MOE_PREPACK_CACHE_MAX', '2')) - - if not hasattr(_prepack_weight_sf, '_cache'): - _prepack_weight_sf._cache = {} - _prepack_weight_sf._cache_order = [] # LRU order - if cache_key in _prepack_weight_sf._cache: - # Move to end (most recently used) - _prepack_weight_sf._cache_order.remove(cache_key) - _prepack_weight_sf._cache_order.append(cache_key) - return _prepack_weight_sf._cache[cache_key] - - assert weight_sf.dtype == torch.float8_e4m3fn, weight_sf.dtype - - from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb - - E = weight_sf.shape[0] - # M_for_layout controls CUTLASS SFB layout sizing. - # ASSUMPTION: SFB layout size is M-independent (CUTLASS tiling is over M - # but the scale factor block structure depends on N,K only). If this is - # wrong, we need to prepack per-expert with actual M. Verified only for - # M=128 — TODO: test with M=1, M=256 to confirm. 
- M_for_layout = 128 - - # Pre-allocate output tensor and fill in-place to avoid 2× peak memory - # (torch.stack would hold all expert tensors + the final stack = ~3.5 GiB) - packed0 = prepack_sfb(weight_sf[0], M_for_layout, N, K) - packed = torch.empty( - (E, *packed0.shape), - dtype=packed0.dtype, - device=packed0.device, - ) - packed[0].copy_(packed0) - del packed0 - - for e in range(1, E): - tmp = prepack_sfb(weight_sf[e], M_for_layout, N, K) - packed[e].copy_(tmp) - del tmp - - packed = packed.contiguous() - _prepack_weight_sf._cache[cache_key] = packed - _prepack_weight_sf._cache_order.append(cache_key) - - # Evict oldest entries — keep only _max_cache entries - while len(_prepack_weight_sf._cache) > _max_cache: - oldest = _prepack_weight_sf._cache_order.pop(0) - del _prepack_weight_sf._cache[oldest] - - if MEGA_MOE_DEBUG: - print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} " - f"(was {weight_sf.shape})") - - return packed - - def nvfp4_mega_moe_l1( x_fp4, # (num_tokens, K//2) int8 packed E2M1 x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS - l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major — or prepacked + l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs slot_token, # (num_slots,) int64 — token index per slot alpha=1.0, # fp32 scalar from stage_activation global scale - sfb_prepacked=False, # True if l1_scales is prepacked CUTLASS layout ): """L1 GEMM: gate_up_proj — slot-based, no routing weights. 
@@ -198,10 +113,8 @@ def nvfp4_mega_moe_l1( print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} slots={slot_expert_ids.shape[0]}") x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf - if not sfb_prepacked: - w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales - else: - w_sf_fp8 = l1_scales # already prepacked, skip unpack + w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales + assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l1_scales after unpack dtype={w_sf_fp8.dtype}" slot_out, slot_token = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, @@ -209,7 +122,6 @@ def nvfp4_mega_moe_l1( slot_expert_ids, # 1D per-slot expert IDs slot_token, # 1D per-slot token indices alpha=alpha, - sfb_prepacked=sfb_prepacked, ) print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}") return slot_out, slot_token @@ -219,11 +131,10 @@ def nvfp4_mega_moe_l2( x_fp4, # (num_slots, INTER//2) int8 packed E2M1 x_sf, # (num_slots, sf_k_groups) float8_e4m3fn l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS - l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major — or prepacked + l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs (from L1 routing) slot_token, # (num_slots,) int64 — token index per slot (from L1) alpha=1.0, # fp32 scalar from stage_activation global scale - sfb_prepacked=False, # True if l2_scales is prepacked CUTLASS layout ): """L2 GEMM: down_proj — slot-based, no routing weights. 
@@ -237,19 +148,14 @@ def nvfp4_mega_moe_l2( print(f"[nvfp4_moe_l2] slots={x_fp4.shape[0]} K={K} N={N} native=1") x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf - if not sfb_prepacked: - w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales - else: - w_sf_fp8 = l2_scales # already prepacked - - # slot_expert_ids passed directly from L1 routing — no rebuild needed + w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales + assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l2_scales after unpack dtype={w_sf_fp8.dtype}" slot_out, _ = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l2_weights, w_sf_fp8, slot_expert_ids, # 1D per-slot expert IDs — GEMM handles directly alpha=alpha, - sfb_prepacked=sfb_prepacked, ) return slot_out # (num_slots, HIDDEN) bfloat16 @@ -421,21 +327,22 @@ def nvfp4_mega_moe_full( assert slot_token.numel() == num_slots assert slot_weight.numel() == num_slots - # Prepack SFB weight scales into CUTLASS layout (lazy, once per layer) - l1_N = l1_w.shape[2] - l1_K = l1_w.shape[1] * 2 - l1_sf_prepacked = _prepack_weight_sf(l1_sf, l1_N, l1_K, "l1") + # SFB weight scales are remapped per-expert inside CUTLASS on each call. + # ───────────────────────────────────────────────────────────────────── + # NO PREPACK CACHE — see README for rationale. + # DO NOT add a prepack cache. Previous attempts caused: + # - OOM: ~1.75 GiB per prepacked tensor × 61 layers × 2 (L1+L2) = ~214 GiB + # - Peak memory 2× during torch.stack before eviction + # - CUDA graph use-after-free on evicted entries + # - M_for_layout=128 assumption (unverified M-independence) + # The SFB remap is a small scatter kernel (~µs) — not the bottleneck.
+ # ───────────────────────────────────────────────────────────────────── - l2_N = l2_w.shape[2] - l2_K = l2_w.shape[1] * 2 - l2_sf_prepacked = _prepack_weight_sf(l2_sf, l2_N, l2_K, "l2") - - # Step 2: L1 GEMM — slot-based, no routing weights, prepacked SFB + # Step 2: L1 GEMM — slot-based, no routing weights l1_slots, _ = nvfp4_mega_moe_l1( - x_fp4, x_sf, l1_w, l1_sf_prepacked, + x_fp4, x_sf, l1_w, l1_sf, slot_expert_local, slot_token, alpha=l1_alpha, - sfb_prepacked=True, ) # (num_slots, 2*INTER) bfloat16 # Post-L1 shape asserts @@ -469,12 +376,11 @@ def nvfp4_mega_moe_full( _l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]") - # Step 5: L2 GEMM — slot-based, no routing weights, prepacked SFB + # Step 5: L2 GEMM — slot-based, no routing weights l2_slots = nvfp4_mega_moe_l2( - l1_fp4, l1_sf_out, l2_w, l2_sf_prepacked, + l1_fp4, l1_sf_out, l2_w, l2_sf, slot_expert_local, slot_token, alpha=l2_alpha, - sfb_prepacked=True, ) # (num_slots, HIDDEN) bfloat16 if MEGA_MOE_DEBUG: