single_shot: fix memory (no double-loading MoE weights), FMHA short-seq fallback
- Don't cache MoE/SE expert weights in layer_w (handled by runners) This saves ~10.6GB/layer × 61 = ~647GB of double-loaded GPU memory - Add FMHA fallback for seq_len < 128 (known kernel limitation: zero-padding dilutes softmax). TODO: fix kernel to mask padded entries. - Free all_w and empty GPU caches after building runners
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.
|
||||
"""Single-shot DSV4-Pro inference PYTORCH VERSION — Full 61-layer pipeline, 8-GPU.
|
||||
|
||||
Reference implementation exercising the production kernel stack end-to-end.
|
||||
This file should be usable as ground truth when integrating into vLLM or SGLang.
|
||||
THIS is a pure-PyTorch reference reimplementation that bypasses every kernel in the production stack.
|
||||
|
||||
IT IS ONLY TO BE USED FOR REFERENCE FOR THE ACTUAL PRODUCTION KERNEL SINGLE SHOT
|
||||
|
||||
Architecture (paper §2, verified against HuggingFace modeling_deepseek_v4.py):
|
||||
X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
|
||||
|
||||
@@ -371,13 +371,19 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
|
||||
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
|
||||
Returns: (T, n_h, hd) BF16
|
||||
"""
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
|
||||
# Reshape for kernel: q=(n_h, T, hd), k=(1, seq_len, hd), v same
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
|
||||
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
|
||||
v = all_kv.unsqueeze(0).contiguous()
|
||||
KERNEL LIMITATION: The 6-warp TMA FMHA kernel pads N to 128.
|
||||
When seq_len < 128, the zero-padded entries dilute the softmax
|
||||
(e.g. seq_len=1 gives softmax over 128, 127 of which are zero,
|
||||
reducing max attention weight from 1.0 to 1/128). This must be
|
||||
fixed in the kernel (skip zero-padded entries in softmax). Until
|
||||
then, we use PyTorch scaled_dot_product_attention for short
|
||||
sequences where the padding would dominate.
|
||||
|
||||
TODO: Fix FMHA kernel to handle N < 128 correctly (mask padded
|
||||
entries from softmax). This is a kernel bug, not a design choice.
|
||||
"""
|
||||
FMHA_MIN_SEQ = 128 # Minimum seq_len for correct FMHA kernel output
|
||||
|
||||
# Sinks: per-head logit bias
|
||||
sinks = w.get(f"{pfx}.sinks")
|
||||
@@ -385,12 +391,34 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
|
||||
if sinks is not None:
|
||||
sink_bias = sinks.to(device=dev).float().reshape(n_h)
|
||||
|
||||
attn_out = dsv4_attention(
|
||||
q=q, k=k, v=v, scale=scale,
|
||||
n_comp=0, # compressed KV already concatenated in all_kv
|
||||
sink_bias=sink_bias,
|
||||
) # (n_h, T, hd)
|
||||
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
if seq_len >= FMHA_MIN_SEQ:
|
||||
# Production path: 6-warp TMA multi-tile kernel
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd)
|
||||
k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA
|
||||
v = all_kv.unsqueeze(0).contiguous()
|
||||
attn_out = dsv4_attention(
|
||||
q=q, k=k, v=v, scale=scale,
|
||||
n_comp=0, sink_bias=sink_bias,
|
||||
) # (n_h, T, hd)
|
||||
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
else:
|
||||
# Short-sequence path: PyTorch SDPA
|
||||
# TODO: Replace with fixed FMHA kernel once softmax padding is handled
|
||||
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() # (n_h, seq_len, hd)
|
||||
v_exp = k_exp.clone()
|
||||
q_in = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
|
||||
if sink_bias is not None:
|
||||
sink_logits = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1)
|
||||
combined = torch.cat([scores, sink_logits], dim=-1)
|
||||
combined = combined - combined.max(-1, keepdim=True).values
|
||||
probs = torch.softmax(combined.float(), -1).bfloat16()
|
||||
attn_w = probs[..., :-1]
|
||||
else:
|
||||
attn_w = torch.softmax(scores.float(), -1).bfloat16()
|
||||
attn_out = torch.matmul(attn_w, v_exp).permute(1, 0, 2) # (T, n_h, hd)
|
||||
return attn_out
|
||||
|
||||
|
||||
# =====================================================================
|
||||
@@ -758,11 +786,15 @@ def main():
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
# Cache layer weights
|
||||
print("Caching layer weights to GPUs...")
|
||||
# Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners)
|
||||
# This avoids double-loading ~10GB/layer of expert FP4 weights
|
||||
print("Caching layer weights to GPUs (excluding MoE expert weights)...")
|
||||
devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = cache_layer_weights(all_w, n_layers, devs)
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS):
|
||||
torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
print(f" {time.time()-t0:.1f}s")
|
||||
|
||||
# Load compressor/indexer weights
|
||||
@@ -932,13 +964,23 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
|
||||
se.finalize_weights()
|
||||
|
||||
|
||||
def cache_layer_weights(all_w, n_layers, devices):
|
||||
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
|
||||
"""Cache per-layer weights to GPUs, EXCLUDING MoE expert weights.
|
||||
|
||||
MoE expert weights (model.layers.{li}.mlp.experts.*) are handled by
|
||||
Nvfp4MoE runners with stacked tensors. Shared expert weights are handled
|
||||
by Nvfp4SharedExpert runners. Including them here would double-load
|
||||
~10.6GB/layer of FP4 expert weights.
|
||||
"""
|
||||
cached = {}
|
||||
for li in range(n_layers):
|
||||
dev = devices[li % len(devices)]
|
||||
pfx = f"model.layers.{li}."
|
||||
w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items() if k.startswith(pfx)}
|
||||
w = {k: v.to(device=dev, non_blocking=True)
|
||||
for k, v in all_w.items()
|
||||
if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k}
|
||||
cached[li] = w
|
||||
if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers")
|
||||
return cached
|
||||
|
||||
def load_weights(checkpoint_dir):
|
||||
|
||||
Reference in New Issue
Block a user