From d40821c84324b693e4208a72fabf8d07c165bef6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 22:49:15 +0000 Subject: [PATCH] single_shot: fix memory (no double-loading MoE weights), FMHA short-seq fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Don't cache MoE/SE expert weights in layer_w (handled by runners) This saves ~10.6GB/layer × 61 = ~647GB of double-loaded GPU memory - Add FMHA fallback for seq_len < 128 (known kernel limitation: zero-padding dilutes softmax). TODO: fix kernel to mask padded entries. - Free all_w and empty GPU caches after building runners --- single_shot_PYTORCH_REFERENCE.py | 7 +-- single_shot_inference.py | 76 +++++++++++++++++++++++++------- 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/single_shot_PYTORCH_REFERENCE.py b/single_shot_PYTORCH_REFERENCE.py index 20e8a762..5f443223 100644 --- a/single_shot_PYTORCH_REFERENCE.py +++ b/single_shot_PYTORCH_REFERENCE.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU. +"""Single-shot DSV4-Pro inference PYTORCH VERSION — Full 61-layer pipeline, 8-GPU. -Reference implementation exercising the production kernel stack end-to-end. -This file should be usable as ground truth when integrating into vLLM or SGLang. +THIS is a pure-PyTorch reference reimplementation that bypasses every kernel in the production stack. + +IT IS ONLY TO BE USED FOR REFERENCE FOR THE ACTUAL PRODUCTION KERNEL SINGLE SHOT Architecture (paper §2, verified against HuggingFace modeling_deepseek_v4.py): X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid diff --git a/single_shot_inference.py b/single_shot_inference.py index 2c3bcd30..db593cb9 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -371,13 +371,19 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w q_heads: (T, n_h, hd), all_kv: (seq_len, hd) Returns: (T, n_h, hd) BF16 - """ - from dsv4.kernels.attention.production import dsv4_attention - # Reshape for kernel: q=(n_h, T, hd), k=(1, seq_len, hd), v same - q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd) - k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA - v = all_kv.unsqueeze(0).contiguous() + KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128. + When seq_len < 128, the zero-padded entries dilute the softmax + (e.g. seq_len=1 gives softmax over 128, 127 of which are zero, + reducing max attention weight from 1.0 to 1/128). This must be + fixed in the kernel (skip zero-padded entries in softmax). Until + then, we use PyTorch scaled_dot_product_attention for short + sequences where the padding would dominate. + + TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded + entries from softmax). This is a kernel bug, not a design choice. + """ + FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output # Sinks: per-head logit bias sinks = w.get(f"{pfx}.sinks") @@ -385,12 +391,34 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w if sinks is not None: sink_bias = sinks.to(device=dev).float().reshape(n_h) - attn_out = dsv4_attention( - q=q, k=k, v=v, scale=scale, - n_comp=0, # compressed KV already concatenated in all_kv - sink_bias=sink_bias, - ) # (n_h, T, hd) - return attn_out.permute(1, 0, 2) # (T, n_h, hd) + if seq_len >= FMHA_MIN_SEQ: + # Production path: 6-warp TMA multi-tile kernel + from dsv4.kernels.attention.production import dsv4_attention + q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd) + k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA + v = all_kv.unsqueeze(0).contiguous() + attn_out = dsv4_attention( + q=q, k=k, v=v, scale=scale, + n_comp=0, sink_bias=sink_bias, + ) # (n_h, T, hd) + return attn_out.permute(1, 0, 2) # (T, n_h, hd) + else: + # Short-sequence path: PyTorch SDPA + # TODO: Replace with fixed FMHA kernel once softmax padding is handled + k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd) + v_exp = k_exp.clone() + q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd) + scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale + if sink_bias is not None: + sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1) + combined = torch.cat([scores, sink_logits], dim=-1) + combined = combined - combined.max(-1, keepdim=True).values + probs = torch.softmax(combined.float(), -1).bfloat16() + attn_w = probs[..., :-1] + else: + attn_w = torch.softmax(scores.float(), -1).bfloat16() + attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd) + return attn_out # ===================================================================== @@ -758,11 +786,15 @@ def main(): if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) - # Cache layer weights - print("Caching layer weights to GPUs...") + # Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners) + # This avoids double-loading ~10GB/layer of expert FP4 weights + print("Caching layer weights to GPUs (excluding MoE expert weights)...") devs = [f"cuda:{g}" for g in range(NUM_GPUS)] - layer_w = cache_layer_weights(all_w, n_layers, devs) + layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs) del all_w; import gc; gc.collect() + for g in range(NUM_GPUS): + torch.cuda.set_device(g); torch.cuda.empty_cache() + torch.cuda.set_device(0) print(f" {time.time()-t0:.1f}s") # Load compressor/indexer weights @@ -932,13 +964,23 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg): se.finalize_weights() -def cache_layer_weights(all_w, n_layers, devices): +def _cache_layer_weights_no_experts(all_w, n_layers, devices): + """Cache per-layer weights to GPUs, EXCLUDING MoE expert weights. + + MoE expert weights (model.layers.{li}.mlp.experts.*) are handled by + Nvfp4MoE runners with stacked tensors. Shared expert weights are handled + by Nvfp4SharedExpert runners. Including them here would double-load + ~10.6GB/layer of FP4 expert weights. + """ cached = {} for li in range(n_layers): dev = devices[li % len(devices)] pfx = f"model.layers.{li}." - w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items() if k.startswith(pfx)} + w = {k: v.to(device=dev, non_blocking=True) + for k, v in all_w.items() + if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k} cached[li] = w + if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers") return cached def load_weights(checkpoint_dir):