single_shot: fix memory (no double-loading MoE weights), FMHA short-seq fallback

- Don't cache MoE/SE expert weights in layer_w (handled by runners)
  This saves ~10.6GB/layer × 61 = ~647GB of double-loaded GPU memory
- Add FMHA fallback for seq_len < 128 (known kernel limitation:
  zero-padding dilutes softmax). TODO: fix kernel to mask padded entries.
- Free all_w and empty GPU caches after building runners
This commit is contained in:
2026-05-31 22:49:15 +00:00
parent 91568e12d4
commit d40821c843
2 changed files with 63 additions and 20 deletions

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.
"""Single-shot DSV4-Pro inference PYTORCH VERSION — Full 61-layer pipeline, 8-GPU.
Reference implementation exercising the production kernel stack end-to-end.
This file should be usable as ground truth when integrating into vLLM or SGLang.
THIS is a pure-PyTorch reference reimplementation that bypasses every kernel in the production stack.
IT IS ONLY TO BE USED FOR REFERENCE FOR THE ACTUAL PRODUCTION KERNEL SINGLE SHOT
Architecture (paper §2, verified against HuggingFace modeling_deepseek_v4.py):
X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid

View File

@@ -371,13 +371,19 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
Returns: (T, n_h, hd) BF16
"""
from dsv4.kernels.attention.production import dsv4_attention
# Reshape for kernel: q=(n_h, T, hd), k=(1, seq_len, hd), v same
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
v = all_kv.unsqueeze(0).contiguous()
KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128.
When seq_len < 128, the zero-padded entries dilute the softmax
(e.g. seq_len=1 gives softmax over 128, 127 of which are zero,
reducing max attention weight from 1.0 to 1/128). This must be
fixed in the kernel (skip zero-padded entries in softmax). Until
then, we use PyTorch scaled_dot_product_attention for short
sequences where the padding would dominate.
TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded
entries from softmax). This is a kernel bug, not a design choice.
"""
FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output
# Sinks: per-head logit bias
sinks = w.get(f"{pfx}.sinks")
@@ -385,12 +391,34 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
if sinks is not None:
sink_bias = sinks.to(device=dev).float().reshape(n_h)
attn_out = dsv4_attention(
q=q, k=k, v=v, scale=scale,
n_comp=0, # compressed KV already concatenated in all_kv
sink_bias=sink_bias,
) # (n_h, T, hd)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
if seq_len >= FMHA_MIN_SEQ:
# Production path: 6-warp TMA multi-tile kernel
from dsv4.kernels.attention.production import dsv4_attention
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
v = all_kv.unsqueeze(0).contiguous()
attn_out = dsv4_attention(
q=q, k=k, v=v, scale=scale,
n_comp=0, sink_bias=sink_bias,
) # (n_h, T, hd)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
else:
# Short-sequence path: PyTorch SDPA
# TODO: Replace with fixed FMHA kernel once softmax padding is handled
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)
v_exp = k_exp.clone()
q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd)
scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
if sink_bias is not None:
sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1)
combined = torch.cat([scores, sink_logits], dim=-1)
combined = combined - combined.max(-1, keepdim=True).values
probs = torch.softmax(combined.float(), -1).bfloat16()
attn_w = probs[..., :-1]
else:
attn_w = torch.softmax(scores.float(), -1).bfloat16()
attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd)
return attn_out
# =====================================================================
@@ -758,11 +786,15 @@ def main():
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
# Cache layer weights
print("Caching layer weights to GPUs...")
# Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners)
# This avoids double-loading ~10GB/layer of expert FP4 weights
print("Caching layer weights to GPUs (excluding MoE expert weights)...")
devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
layer_w = cache_layer_weights(all_w, n_layers, devs)
layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
del all_w; import gc; gc.collect()
for g in range(NUM_GPUS):
torch.cuda.set_device(g); torch.cuda.empty_cache()
torch.cuda.set_device(0)
print(f" {time.time()-t0:.1f}s")
# Load compressor/indexer weights
@@ -932,13 +964,23 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
se.finalize_weights()
def cache_layer_weights(all_w, n_layers, devices):
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
"""Cache per-layer weights to GPUs, EXCLUDING MoE expert weights.
MoE expert weights (model.layers.{li}.mlp.experts.*) are handled by
Nvfp4MoE runners with stacked tensors. Shared expert weights are handled
by Nvfp4SharedExpert runners. Including them here would double-load
~10.6GB/layer of FP4 expert weights.
"""
cached = {}
for li in range(n_layers):
dev = devices[li % len(devices)]
pfx = f"model.layers.{li}."
w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items() if k.startswith(pfx)}
w = {k: v.to(device=dev, non_blocking=True)
for k, v in all_w.items()
if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k}
cached[li] = w
if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers")
return cached
def load_weights(checkpoint_dir):