WIP: remove prepack cache, remap SFB per-call inside CUTLASS
- Deleted _prepack_weight_sf() and all cache/LRU logic
- L1/L2 pass raw scales, sfb_prepacked param removed
- cutlass_grouped_nvfp4_gemm always uses remap path
- README: big warning table explaining why prepack cache must not return
- Updated all doc references

NOT PUSHED YET: pending Mike review
29
README.md
@@ -70,14 +70,21 @@ The kernel uses a **slot representation** instead of collapsing expert outputs e
|
||||
|
||||
The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final `index_add_` scatter.
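To make the slot idea concrete, here is a minimal pure-Python sketch under toy assumptions: each expert is reduced to three scalar weights `(gate, up, down)`, and the function name `moe_slots` is hypothetical, not part of the kernel. It only illustrates the ordering described above: SiLU+Mul and the second projection run per slot without routing weights, and the weight is applied once in the final scatter-add.

```python
import math

def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))

def moe_slots(x, slot_token, slot_expert, slot_weight, experts):
    """Toy slot-based MoE.

    x           : per-token scalar activations
    slot_token  : token index for each (token, topk) slot
    slot_expert : expert index for each slot
    slot_weight : routing weight for each slot
    experts     : list of (gate, up, down) scalar weights (illustrative only)
    """
    out = [0.0] * len(x)
    for tok, e, w in zip(slot_token, slot_expert, slot_weight):
        gate_w, up_w, down_w = experts[e]
        h = silu(gate_w * x[tok]) * (up_w * x[tok])  # SiLU+Mul per slot, unweighted
        y = down_w * h                               # second projection, still unweighted
        out[tok] += w * y                            # routing weight applied once, at the scatter
    return out

# Token 0 is routed to experts 0 and 1; token 1 only to expert 0; token 2 to none.
experts = [(1.0, 2.0, 0.5), (-1.0, 1.0, 3.0)]
result = moe_slots([1.0, 2.0, 5.0], [0, 0, 1], [0, 1, 0], [0.75, 0.25, 1.0], experts)
```

Because the nonlinearity is evaluated before any weighting, each slot's contribution stays linear in its routing weight, which is the property the slot representation is meant to preserve.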
|
||||
|
||||
### Prepacked SFB (Weight Scale Factors)
|
||||
### SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached
|
||||
|
||||
Weight scale factors (SFB) are pre-remapped into CUTLASS interleaved layout once at first forward pass (lazily cached). This eliminates per-GEMM:
|
||||
- 1 `cudaMalloc`-ish allocation for SFB
|
||||
- 1 `cudaMemsetAsync` for SFB padding
|
||||
- 1 remap kernel launch for SFB
|
||||
Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. The remap is a lightweight scatter kernel that takes on the order of microseconds, so it is NOT the bottleneck relative to the GEMM itself.
|
||||
|
||||
The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked SFB pointer directly. Only SFA (activation scales) is remapped dynamically — those change every forward pass.
|
||||
⚠️ **DO NOT ADD A PREPACK CACHE FOR SFB.** Previous attempts caused critical issues:

| Problem | Impact |
|---------|--------|
| **OOM** | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) ≈ 214 GiB, which exceeds B200 capacity |
| **Peak memory 2×** | `torch.stack` held all expert tensors + final stack simultaneously before LRU eviction |
| **CUDA graph trap** | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash |
| **M-dependent layout** | `prepack_sfb(M=128)` assumed the SFB layout size is M-independent, which was never verified beyond M=128. If the assumption is wrong, the entire prepack is invalid |
| **Cross-layer cache collision** | Tag-based cache (`"l1"`/`"l2"`) returned layer N-1's data for layer N. Fixed with a `data_ptr` key, but the cache itself was the root problem |

The per-call remap costs microseconds. The cache cost was hours of debugging. Don't repeat this mistake.
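The OOM figure in the table can be reproduced with a few lines of arithmetic. The sketch below uses only the numbers quoted above (1.75 GiB per prepacked tensor, 61 MoE layers, two GEMMs per layer); the helper name `prepack_cache_gib` is hypothetical.

```python
GIB_PER_PREPACKED_TENSOR = 1.75  # per-tensor size quoted in the table
MOE_LAYERS = 61
GEMMS_PER_LAYER = 2              # L1 + L2

def prepack_cache_gib(entries: int, per_tensor_gib: float = GIB_PER_PREPACKED_TENSOR) -> float:
    """Memory held by a prepack cache with `entries` resident tensors."""
    return entries * per_tensor_gib

unbounded = prepack_cache_gib(MOE_LAYERS * GEMMS_PER_LAYER)  # every layer's L1 and L2 cached
bounded = prepack_cache_gib(2)                               # a 2-entry LRU (one layer's L1 + L2)

print(f"unbounded cache: {unbounded:.1f} GiB")  # 213.5 GiB
print(f"2-entry LRU:     {bounded:.1f} GiB")    # 3.5 GiB
```

A 2-entry bound keeps the footprint small, but that is exactly the configuration whose evictions conflict with CUDA graphs, so bounding the cache trades the OOM row for the use-after-free row.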
|
||||
|
||||
---
|
||||
|
||||
@@ -117,7 +124,7 @@ The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked S
|
||||
6. Profile run (warmup)
|
||||
└─ First forward pass to allocate KV cache, etc.
|
||||
└─ This is where the CUTLASS GEMM first executes
|
||||
└─ SFB weight scales are prepacked into CUTLASS layout (lazy, cached)
|
||||
└─ SFB weight scales remapped per-expert inside CUTLASS (no cache)
|
||||
|
||||
7. Ready to serve
|
||||
```
|
||||
@@ -129,7 +136,7 @@ The `cutlass_nvfp4_gemm_run_prepacked_sfb` C entry point accepts the prepacked S
|
||||
```
|
||||
nvfp4_megamoe_kernel/
|
||||
├── __init__.py # Public API exports
|
||||
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation, prepack
|
||||
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation
|
||||
├── weight_transform.py # Weight prep: fold global scale, pack UE4M3
|
||||
├── symm_buffer.py # GPU buffer allocation for MoE dispatch
|
||||
│
|
||||
@@ -150,9 +157,9 @@ nvfp4_megamoe_kernel/
|
||||
|------|-------------|--------------|
|
||||
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
|
||||
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
|
||||
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4), `unpack_ue4m3_u32`, and `_prepack_weight_sf` (lazy SFB prepack). |
|
||||
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (with prepacked SFB), write results to slot buffer. No routing weights — caller handles scatter. |
|
||||
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA remap. SFB remap done once at prepack time. Two GEMM entry points: standard (both remap) and prepacked-sfb (SFA remap only). |
|
||||
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4) and `unpack_ue4m3_u32`. NO prepack cache — SFB remapped per-call inside CUTLASS. |
|
||||
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (SFB remapped inside C extension), write results to slot buffer. No routing weights — caller handles scatter. |
|
||||
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA and SFB remap. |
|
||||
| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |
|
||||
|
||||
---
|
||||
|
||||
@@ -170,8 +170,8 @@ int cutlass_nvfp4_gemm_run(
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
|
||||
|
||||
int sfa_size = cute::size(layout_SFA);
|
||||
int sfb_size = cute::size(layout_SFB);
|
||||
int sfa_size = cute::cosize(layout_SFA);
|
||||
int sfb_size = cute::cosize(layout_SFB);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
cutlass::device_memory::allocation<ElementSF> sfa_cutlass(sfa_size);
|
||||
@@ -223,7 +223,7 @@ extern "C" int cutlass_nvfp4_sfb_size(
|
||||
) {
|
||||
using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
*out_size = cute::size(layout_SFB);
|
||||
*out_size = cute::cosize(layout_SFB);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ extern "C" int cutlass_nvfp4_prepack_sfb_run(
|
||||
LayoutSFB layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
|
||||
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
|
||||
|
||||
int sfb_size = cute::size(layout_SFB);
|
||||
int sfb_size = cute::cosize(layout_SFB);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
cudaMemsetAsync(static_cast<ElementSF*>(SFB_cutlass_ptr), 0, sfb_size * sizeof(ElementSF), stream);
|
||||
@@ -278,7 +278,7 @@ extern "C" int cutlass_nvfp4_gemm_run_prepacked_sfb(
|
||||
using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB;
|
||||
using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF;
|
||||
|
||||
int sfa_size = cute::size(layout_SFA);
|
||||
int sfa_size = cute::cosize(layout_SFA);
|
||||
int K_sf = K / InputSFVectorSize;
|
||||
|
||||
// Only remap SFA (activation scales) — SFB is prepacked
|
||||
|
||||
@@ -57,16 +57,16 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, # (num_slots_or_tokens, K_half) int8 packed E2M1
|
||||
x_sf, # (num_slots_or_tokens, sf_k) float8_e4m3fn block scales
|
||||
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
|
||||
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major — or prepacked (E_per_rank, sfb_size) if sfb_prepacked=True
|
||||
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange)
|
||||
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
|
||||
sfb_prepacked=False, # True if weight_sf is already prepacked into CUTLASS layout
|
||||
):
|
||||
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
|
||||
|
||||
Takes 1D per-slot expert IDs and token indices (pre-built by caller).
|
||||
Returns slot-based output: one row per (token, topk) slot.
|
||||
SFB weight scales are remapped per-expert inside CUTLASS on each call.
|
||||
NO prepack cache — see nvfp4_mega_moe.py for rationale.
|
||||
|
||||
For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows.
|
||||
For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots).
|
||||
@@ -100,7 +100,7 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
|
||||
f"experts={num_experts}")
|
||||
|
||||
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
@@ -118,14 +118,21 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
|
||||
if MEGA_MOE_DEBUG and e < 3 and M_expert > 0:
|
||||
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
|
||||
f"w shape={expert_w.shape} sfb_prepacked={sfb_prepacked}")
|
||||
f"w shape={expert_w.shape}")
|
||||
|
||||
# Shape/dtype contract asserts — SFB bugs hide in silent shape mismatches
|
||||
assert expert_x.shape == (M_expert, K // 2), f"expert_x shape {expert_x.shape} != ({M_expert}, {K // 2})"
|
||||
assert expert_x_sf.shape == (M_expert, K // 16), f"SFA shape {expert_x_sf.shape} != ({M_expert}, {K // 16})"
|
||||
assert expert_w.shape == (K // 2, N), f"expert_w shape {expert_w.shape} != ({K // 2}, {N})"
|
||||
assert expert_w_sf.shape == (K // 16, N), f"SFB shape {expert_w_sf.shape} != ({K // 16}, {N})"
|
||||
assert expert_x_sf.dtype == torch.float8_e4m3fn, f"SFA dtype {expert_x_sf.dtype}"
|
||||
assert expert_w_sf.dtype == torch.float8_e4m3fn, f"SFB dtype {expert_w_sf.dtype}"
|
||||
|
||||
expert_out = cutlass_nvfp4_blockscaled_gemm(
|
||||
expert_x, expert_x_sf,
|
||||
expert_w, expert_w_sf,
|
||||
M_expert, N, K,
|
||||
alpha=alpha,
|
||||
sfb_prepacked=sfb_prepacked,
|
||||
)
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
|
||||
@@ -35,10 +35,11 @@ def unpack_ue4m3_u32(x_u32):
|
||||
whose bits are 0x3F (1.875 in float8_e4m3fn), NOT the integer 63.
|
||||
|
||||
PyTorch doesn't implement bitwise ops for uint32 CUDA tensors, so we cast to int32 first.
|
||||
Supports ND tensors — last dim is the packed dim (N words → N*4 float8 values).
|
||||
"""
|
||||
# PyTorch lacks bitwise ops for uint32 CUDA tensors, so use int32
|
||||
x_i32 = x_u32.to(torch.int32)
|
||||
M, N = x_i32.shape
|
||||
*prefix, n_words = x_i32.shape
|
||||
|
||||
# Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn
|
||||
b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
@@ -46,12 +47,12 @@ def unpack_ue4m3_u32(x_u32):
|
||||
b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
|
||||
# Interleave into (M, N*4)
|
||||
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
|
||||
out[:, 0::4] = b0
|
||||
out[:, 1::4] = b1
|
||||
out[:, 2::4] = b2
|
||||
out[:, 3::4] = b3
|
||||
# Interleave into (*prefix, n_words*4)
|
||||
out = torch.empty(*prefix, n_words * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
|
||||
out[..., 0::4] = b0
|
||||
out[..., 1::4] = b1
|
||||
out[..., 2::4] = b2
|
||||
out[..., 3::4] = b3
|
||||
return out
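As a reference for the byte order used above, the following pure-Python sketch unpacks 32-bit words into four bytes, least-significant byte first, and decodes a byte as an e4m3fn bit pattern rather than as an integer. The helper names `unpack_u32_words` and `decode_e4m3fn` are hypothetical and are not part of this module; the decoder assumes the usual e4m3fn encoding (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities).

```python
def unpack_u32_words(words):
    """Split each 32-bit word into 4 bytes, lowest byte first (b0, b1, b2, b3)."""
    out = []
    for w in words:
        w &= 0xFFFFFFFF  # treat the value as an unsigned 32-bit word
        out.extend((w >> shift) & 0xFF for shift in (0, 8, 16, 24))
    return out

def decode_e4m3fn(byte):
    """Interpret one byte as an e4m3fn bit pattern and return its float value."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                        # the only NaN encoding per sign
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6      # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

packed = [0x04030201, 0x40383000]
raw_bytes = unpack_u32_words(packed)   # [1, 2, 3, 4, 0x00, 0x30, 0x38, 0x40]
scale = decode_e4m3fn(0x38)            # 1.0, not the integer 56
```

The same distinction is why the function above uses `.view(torch.float8_e4m3fn)` on the extracted bytes: a view reinterprets the bits, whereas a value cast would convert the integer.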
|
||||
|
||||
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell)
|
||||
@@ -89,100 +90,14 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
|
||||
# Main kernel entry points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _prepack_weight_sf(weight_sf, N, K, tag):
|
||||
"""Lazily prepack SFB weight scales into CUTLASS layout (once per tag).
|
||||
|
||||
Returns a tensor of shape (E, sfb_size) with SFB already in CUTLASS
|
||||
interleaved layout, skipping the per-call remap+memset+alloc.
|
||||
"""
|
||||
cache_key = (
|
||||
tag,
|
||||
weight_sf.data_ptr(),
|
||||
tuple(weight_sf.shape),
|
||||
str(weight_sf.dtype),
|
||||
weight_sf.device.index,
|
||||
N,
|
||||
K,
|
||||
)
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# PREPACK CACHE — LRU with configurable max size
|
||||
#
|
||||
# Each prepacked SFB tensor is ~1.75 GiB per rank. With 61 MoE layers
|
||||
# × 2 (L1+L2), an unbounded cache would consume ~214 GiB — well beyond
|
||||
# B200 capacity. Since vLLM calls layers sequentially, only 2 entries
|
||||
# are needed at a time (current layer's L1 + L2).
|
||||
#
|
||||
# WARNING: If you enable CUDA graphs in the future, multiple layers may
|
||||
# be captured in a single graph replay, and the prepacked tensors must
|
||||
# remain alive for the graph's lifetime. In that case, increase
|
||||
# MEGA_MOE_PREPACK_CACHE_MAX to cover all captured layers, or switch
|
||||
# to a persistent pre-allocation scheme. The default of 2 will cause
|
||||
# use-after-free on evicted entries if CUDA graphs span >1 layer.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
_max_cache = int(os.environ.get('MEGA_MOE_PREPACK_CACHE_MAX', '2'))
|
||||
|
||||
if not hasattr(_prepack_weight_sf, '_cache'):
|
||||
_prepack_weight_sf._cache = {}
|
||||
_prepack_weight_sf._cache_order = [] # LRU order
|
||||
if cache_key in _prepack_weight_sf._cache:
|
||||
# Move to end (most recently used)
|
||||
_prepack_weight_sf._cache_order.remove(cache_key)
|
||||
_prepack_weight_sf._cache_order.append(cache_key)
|
||||
return _prepack_weight_sf._cache[cache_key]
|
||||
|
||||
assert weight_sf.dtype == torch.float8_e4m3fn, weight_sf.dtype
|
||||
|
||||
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb
|
||||
|
||||
E = weight_sf.shape[0]
|
||||
# M_for_layout controls CUTLASS SFB layout sizing.
|
||||
# ASSUMPTION: SFB layout size is M-independent (CUTLASS tiling is over M
|
||||
# but the scale factor block structure depends on N,K only). If this is
|
||||
# wrong, we need to prepack per-expert with actual M. Verified only for
|
||||
# M=128 — TODO: test with M=1, M=256 to confirm.
|
||||
M_for_layout = 128
|
||||
|
||||
# Pre-allocate output tensor and fill in-place to avoid 2× peak memory
|
||||
# (torch.stack would hold all expert tensors + the final stack = ~3.5 GiB)
|
||||
packed0 = prepack_sfb(weight_sf[0], M_for_layout, N, K)
|
||||
packed = torch.empty(
|
||||
(E, *packed0.shape),
|
||||
dtype=packed0.dtype,
|
||||
device=packed0.device,
|
||||
)
|
||||
packed[0].copy_(packed0)
|
||||
del packed0
|
||||
|
||||
for e in range(1, E):
|
||||
tmp = prepack_sfb(weight_sf[e], M_for_layout, N, K)
|
||||
packed[e].copy_(tmp)
|
||||
del tmp
|
||||
|
||||
packed = packed.contiguous()
|
||||
_prepack_weight_sf._cache[cache_key] = packed
|
||||
_prepack_weight_sf._cache_order.append(cache_key)
|
||||
|
||||
# Evict oldest entries — keep only _max_cache entries
|
||||
while len(_prepack_weight_sf._cache) > _max_cache:
|
||||
oldest = _prepack_weight_sf._cache_order.pop(0)
|
||||
del _prepack_weight_sf._cache[oldest]
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[PREPACK] {tag}: E={E} N={N} K={K} packed_shape={packed.shape} "
|
||||
f"(was {weight_sf.shape})")
|
||||
|
||||
return packed
|
||||
|
||||
|
||||
def nvfp4_mega_moe_l1(
|
||||
x_fp4, # (num_tokens, K//2) int8 packed E2M1
|
||||
x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn
|
||||
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
|
||||
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major — or prepacked
|
||||
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token, # (num_slots,) int64 — token index per slot
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
sfb_prepacked=False, # True if l1_scales is prepacked CUTLASS layout
|
||||
):
|
||||
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
|
||||
|
||||
@@ -198,10 +113,8 @@ def nvfp4_mega_moe_l1(
|
||||
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} slots={slot_expert_ids.shape[0]}")
|
||||
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
if not sfb_prepacked:
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
else:
|
||||
w_sf_fp8 = l1_scales # already prepacked, skip unpack
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l1_scales after unpack dtype={w_sf_fp8.dtype}"
|
||||
|
||||
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
@@ -209,7 +122,6 @@ def nvfp4_mega_moe_l1(
|
||||
slot_expert_ids, # 1D per-slot expert IDs
|
||||
slot_token, # 1D per-slot token indices
|
||||
alpha=alpha,
|
||||
sfb_prepacked=sfb_prepacked,
|
||||
)
|
||||
print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}")
|
||||
return slot_out, slot_token
|
||||
@@ -219,11 +131,10 @@ def nvfp4_mega_moe_l2(
|
||||
x_fp4, # (num_slots, INTER//2) int8 packed E2M1
|
||||
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
|
||||
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
|
||||
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major — or prepacked
|
||||
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs (from L1 routing)
|
||||
slot_token, # (num_slots,) int64 — token index per slot (from L1)
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
sfb_prepacked=False, # True if l2_scales is prepacked CUTLASS layout
|
||||
):
|
||||
"""L2 GEMM: down_proj — slot-based, no routing weights.
|
||||
|
||||
@@ -237,19 +148,14 @@ def nvfp4_mega_moe_l2(
|
||||
print(f"[nvfp4_moe_l2] slots={x_fp4.shape[0]} K={K} N={N} native=1")
|
||||
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
if not sfb_prepacked:
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
else:
|
||||
w_sf_fp8 = l2_scales # already prepacked
|
||||
|
||||
# slot_expert_ids passed directly from L1 routing — no rebuild needed
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l2_scales after unpack dtype={w_sf_fp8.dtype}"
|
||||
|
||||
slot_out, _ = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
slot_expert_ids, # 1D per-slot expert IDs — GEMM handles directly
|
||||
alpha=alpha,
|
||||
sfb_prepacked=sfb_prepacked,
|
||||
)
|
||||
return slot_out # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
@@ -421,21 +327,22 @@ def nvfp4_mega_moe_full(
|
||||
assert slot_token.numel() == num_slots
|
||||
assert slot_weight.numel() == num_slots
|
||||
|
||||
# Prepack SFB weight scales into CUTLASS layout (lazy, once per layer)
|
||||
l1_N = l1_w.shape[2]
|
||||
l1_K = l1_w.shape[1] * 2
|
||||
l1_sf_prepacked = _prepack_weight_sf(l1_sf, l1_N, l1_K, "l1")
|
||||
# SFB weight scales are remapped per-expert inside CUTLASS on each call.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# NO PREPACK CACHE — see README for rationale.
|
||||
# DO NOT add a prepack cache. Previous attempts caused:
|
||||
# - OOM: ~1.75 GiB per prepacked tensor × 61 layers × 2 (L1+L2) ≈ 214 GiB
|
||||
# - Peak memory 2× during torch.stack before eviction
|
||||
# - CUDA graph use-after-free on evicted entries
|
||||
# - M_for_layout=128 assumption (unverified M-independence)
|
||||
# The SFB remap is a small scatter kernel (~µs) — not the bottleneck.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
l2_N = l2_w.shape[2]
|
||||
l2_K = l2_w.shape[1] * 2
|
||||
l2_sf_prepacked = _prepack_weight_sf(l2_sf, l2_N, l2_K, "l2")
|
||||
|
||||
# Step 2: L1 GEMM — slot-based, no routing weights, prepacked SFB
|
||||
# Step 2: L1 GEMM — slot-based, no routing weights
|
||||
l1_slots, _ = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf_prepacked,
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
slot_expert_local, slot_token,
|
||||
alpha=l1_alpha,
|
||||
sfb_prepacked=True,
|
||||
) # (num_slots, 2*INTER) bfloat16
|
||||
|
||||
# Post-L1 shape asserts
|
||||
@@ -469,12 +376,11 @@ def nvfp4_mega_moe_full(
|
||||
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item()
|
||||
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
|
||||
|
||||
# Step 5: L2 GEMM — slot-based, no routing weights, prepacked SFB
|
||||
# Step 5: L2 GEMM — slot-based, no routing weights
|
||||
l2_slots = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf_prepacked,
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
slot_expert_local, slot_token,
|
||||
alpha=l2_alpha,
|
||||
sfb_prepacked=True,
|
||||
) # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
|
||||