248 lines
11 KiB
Python
248 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""PART A diagnostic: full forward_attention pipeline comparison.
|
|
|
|
Tests each stage of the production attention pipeline against a PyTorch
|
|
reference for the first few layers. Identifies exactly where the pipeline
|
|
diverges from the reference.
|
|
|
|
Stages tested per layer:
|
|
1. Q projection (q_a → q_a_norm → q_b → q_b_norm)
|
|
2. KV projection + RoPE
|
|
3. KV cache append + compressor
|
|
4. KV gathering (compressed + SWA)
|
|
5. FMHA (production vs SDPA)
|
|
6. Inverse RoPE
|
|
7. Output projection (o_a + o_b)
|
|
8. Full forward_attention output vs reference
|
|
|
|
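Uses REAL model weights and production values.

NOTE: main() currently implements only stages 1-4. Every prefill token runs
stages 1-3; the CSA indexer and KV gather run on prefill token 0 only, and the
gathered KV shapes are printed. Stages 5-8 and the comparison against a PyTorch
reference are not implemented yet.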
|
|
"""
|
|
import sys, os, time, math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────
|
|
def cosine(a, b):
    """Cosine similarity of ``a`` and ``b`` viewed as flat fp32 vectors, as a Python float.

    Both inputs are flattened first, so any shapes with the same number of
    elements can be compared. The 1e-12 term in the denominator keeps the
    result finite (0.0) when either input is all zeros.
    """
    u = a.flatten().float()
    v = b.flatten().float()
    denom = u.norm() * v.norm() + 1e-12
    return (torch.dot(u, v) / denom).item()
|
|
|
|
def rmsnorm(x, w, eps=1e-6):
    """Reference RMSNorm: normalise ``x`` over its last dim in fp32, then scale by ``w``.

    The normalised activations are cast back to ``x``'s original dtype before
    the multiply, and ``w`` is cast to that same dtype, so the result keeps
    ``x``'s dtype.
    """
    in_dtype = x.dtype
    xf = x.float()
    mean_sq = xf.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normed = (xf * inv_rms).to(in_dtype)
    return normed * w.to(in_dtype)
|
|
|
|
# ── Main ─────────────────────────────────────────────────────────
|
|
def _require_weight(w, key, dev):
    """Return ``w[key]`` cast to fp32 on ``dev``; raise KeyError if it is missing.

    Used for norm weights this diagnostic cannot run without, so a missing key
    fails with a readable message instead of an AttributeError on ``None``.
    """
    t = w.get(key)
    if t is None:
        raise KeyError(f"missing weight: {key}")
    return t.to(dev, torch.float32)


def _optional_weight(w, key, dev):
    """Return ``w[key]`` cast to fp32 on ``dev``, or ``None`` when the key is absent."""
    t = w.get(key)
    return None if t is None else t.to(dev, torch.float32)


def _project_q(X_normed, prod_lin, q_a_norm_w, q_b_norm_w, pos, rope_cache,
               n_h, hd, rd, apply_rope):
    """Stage 1: q_a -> (optional) q_a_norm -> q_b -> q_b_norm -> bf16 -> RoPE.

    Returns ``(q_a, q_heads)``: ``q_a`` is the q_a projection *before* q_a_norm
    (it is what the CSA indexer is given), and ``q_heads`` are the roped queries
    reshaped to ``(T, n_h, hd)`` with ``T = pos.shape[0]``.
    """
    T = pos.shape[0]
    q_a = prod_lin['q_a'](X_normed)
    q_a_norm = rmsnorm(q_a, q_a_norm_w) if q_a_norm_w is not None else q_a
    q = prod_lin['q_b'](q_a_norm)
    q = rmsnorm(q, q_b_norm_w).bfloat16()
    q_heads = apply_rope(q.reshape(T, n_h, hd), pos, *rope_cache, rd)
    return q_a, q_heads


def _project_kv_and_append(X_normed, prod_lin, kv_norm_w, pos, rope_cache,
                           hd, rd, kv_cache, apply_rope):
    """Stage 2 (+ SWA part of stage 3): KV projection, optional kv_norm, RoPE, append to SWA cache."""
    T = pos.shape[0]
    kv = prod_lin['kv'](X_normed)
    if kv_norm_w is not None:
        kv = rmsnorm(kv, kv_norm_w)
    # RoPE helper is applied per-head, so view the single KV head as (T, 1, hd).
    kv_roped = apply_rope(kv.reshape(T, 1, hd), pos, *rope_cache, rd).reshape(T, hd)
    kv_cache.append_swa(kv_roped, pos)


def _compress_and_store(X_normed, pos, compressor, indexer, kv_cache, rope_cache,
                        rd, nope_dim, apply_rope):
    """Stage 3 (compressor part): feed one token to the compressor(s) and store any output.

    When the compressor returns an entry, its nope part is FP8-quantised, its
    rope part is roped at ``comp_pos`` in bf16, and both go into ``kv_cache``.
    For CSA layers the indexer's own compressor is fed as well and its output is
    handed to ``kv_cache.set_indexer_keys_fp8``.
    """
    if compressor is None or compressor.ratio <= 0:
        return
    comp_kv_fp32, comp_pos, _ = compressor.forward(X_normed, pos)
    if comp_kv_fp32 is not None:
        from dsv4.kernels.cuda.loader import get_cuda_module
        kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
        nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
        rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
        rope_bf16 = apply_rope(rope_bf16.unsqueeze(1), comp_pos, *rope_cache, rd).squeeze(1)
        nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
        kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
    # NOTE(review): the source this was restored from had lost its indentation,
    # so it was unclear whether this call sat inside `comp_kv_fp32 is not None`.
    # It is placed here so the indexer's compressor sees every token. Confirm
    # against production forward_attention.
    if compressor.is_csa and indexer is not None:
        comp_idx_kv, _, _ = indexer.compressor.forward(X_normed, pos)
        kv_cache.set_indexer_keys_fp8(comp_idx_kv)


def _gather_kv(kv_cache, ratio, topk_idx):
    """Stage 4: gather attention KV using the same ratio-based dispatch as the pipeline.

    ratio == 4 uses the indexer's top-k (clamped to the compressed range),
    ratio > 4 takes every compressed entry, and everything else (or an empty
    compressed cache) uses the SWA window only. Returns
    ``(kv_nope_fp8, kv_nope_scale, kv_rope_bf16)``.
    """
    if kv_cache.n_comp > 0:
        if ratio == 4:
            if topk_idx is None:
                raise RuntimeError(
                    "CSA layer (ratio=4) has compressed KV but no indexer top-k indices")
            tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
            return kv_cache.gather_mixed_selective(tk)
        if ratio > 4:
            return kv_cache.gather_mixed_all()
    return kv_cache.gather_mixed_swa_only()


def main():
    """Run the attention-pipeline diagnostic on the first MAX_LAYERS layers.

    For each layer, every prefill token is embedded, normalised and pushed
    through stages 1-3 (Q projection, KV projection + RoPE + SWA append,
    compression). On prefill token 0 the CSA indexer and the KV gather
    (stage 4) also run, and the gathered KV shapes are printed.

    The model directory can be overridden with the ``DSV4_MODEL_PATH``
    environment variable. The script uses up to 8 visible CUDA devices.
    """
    import json

    MODEL = os.environ.get("DSV4_MODEL_PATH", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
    n_visible = torch.cuda.device_count()
    if n_visible == 0:
        raise SystemExit("PART A diagnostic requires at least one CUDA device")
    NUM_GPUS = min(8, n_visible)
    MAX_LAYERS = 3  # Test first 3 layers

    print("=" * 70)
    print("PART A DIAGNOSTIC: Full Attention Pipeline Comparison")
    print(f"Model: {MODEL}, Layers: {MAX_LAYERS}, GPUs: {NUM_GPUS}")
    print("=" * 70)

    # ── Load model config ──
    with open(os.path.join(MODEL, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    hidden = cfg["hidden_size"]
    rd = cfg.get("qk_rope_head_dim", 64)
    nope_dim = hd - rd
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)

    print(f"Config: {n_layers}L, {n_h}H, hd={hd}, rope={rd}, nope={nope_dim}")
    print(f" o_groups={o_groups}, o_rank={o_rank}, hidden={hidden}")

    # ── Load tokenizer ──
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
    prompt = "The capital of France is"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    print(f"Prompt: '{prompt}' → {len(input_ids)} tokens: {input_ids}")

    # ── RoPE caches, only for the devices the tested layers map to ──
    from dsv4.ops.rope_cuda import build_rope_cache
    used_gpus = sorted({li % NUM_GPUS for li in range(MAX_LAYERS)})
    rope_caches = {}
    for gpu in used_gpus:
        torch.cuda.set_device(gpu)
        rope_caches[gpu] = build_rope_cache(8192, hd, rd, device=f"cuda:{gpu}")

    # ── Production layer setup helpers ──
    from single_shot_inference import (
        load_layer_weights, setup_production_linear, setup_compressor,
        setup_indexer, KVCache, mHCLayer, _apply_rope,
    )

    # Read the embedding table from disk once (to CPU). Each layer copies it to
    # its own device instead of re-loading the whole file per layer.
    embed_cpu = torch.load(os.path.join(MODEL, "model.embed_tokens.weight.pt"),
                           map_location="cpu", weights_only=True).bfloat16()

    # ── Process prefill tokens one by one, per layer ──
    for li in range(MAX_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)

        indexer = None
        try:
            w = load_layer_weights(MODEL, li, dev)
            prod_lin = setup_production_linear(w, li, cfg, dev)
            compressor = setup_compressor(w, li, cfg, dev)
            if compressor is not None and compressor.ratio == 4:
                indexer = setup_indexer(w, li, cfg, dev)
        except Exception as e:  # best-effort: report this layer and move on
            print(f" L{li}: Failed to load weights: {e}")
            continue

        pfx = f"model.layers.{li}.self_attn"
        ratio = compressor.ratio if compressor is not None else 0
        layer_type = "SWA" if ratio == 0 else ("CSA" if ratio == 4 else "HCA")
        print(f"\nL{li} (gpu={gpu}, type={layer_type}, ratio={ratio})")

        kv_cache = KVCache(li, cfg, dev)
        # Not used by stages 1-4 yet; still constructed so setup errors surface.
        mhc_attn = mHCLayer(li, "attn", cfg, dev)  # noqa: F841

        embed_w = embed_cpu.to(dev)
        rope_cache = rope_caches[gpu]
        # Norm weights do not change per token, so cast them once per layer.
        in_norm_w = _require_weight(w, f"model.layers.{li}.input_layernorm.weight", dev)
        q_a_norm_w = _optional_weight(w, f"{pfx}.q_a_norm.weight", dev)
        q_b_norm_w = _require_weight(w, f"{pfx}.q_b_norm.weight", dev)
        kv_norm_w = _optional_weight(w, f"{pfx}.kv_norm.weight", dev)

        for pi, tid in enumerate(input_ids):
            tid_t = torch.tensor([tid], dtype=torch.long, device=dev)
            pos = torch.tensor([pi], dtype=torch.long, device=dev)

            # NOTE(review): every layer is fed the raw token embedding rather
            # than the previous layer's output (the mHC/forward_layer path is
            # not simulated), so inputs for li > 0 presumably differ from
            # production.
            X = mHCLayer.init_state(F.embedding(tid_t, embed_w))
            X_normed = rmsnorm(X, in_norm_w)

            q_a, _q_heads = _project_q(X_normed, prod_lin, q_a_norm_w, q_b_norm_w,
                                       pos, rope_cache, n_h, hd, rd, _apply_rope)
            _project_kv_and_append(X_normed, prod_lin, kv_norm_w, pos, rope_cache,
                                   hd, rd, kv_cache, _apply_rope)
            _compress_and_store(X_normed, pos, compressor, indexer, kv_cache,
                                rope_cache, rd, nope_dim, _apply_rope)

            if pi != 0:
                continue

            # Token 0 only: indexer (CSA) and KV gather, then report shapes.
            topk_idx = None
            if indexer is not None and ratio == 4:
                topk_idx = indexer.forward(q_a, X_normed, kv_cache, pos, layer_idx=li)

            swa_len = kv_cache.get_swa()[0].shape[0]
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = _gather_kv(kv_cache, ratio, topk_idx)
            seq_len = kv_nope_scale.shape[0]

            print(f" Token 0: seq_len={seq_len} swa_len={swa_len} n_comp={kv_cache.n_comp}")
            print(f" kv_nope_fp8 shape={tuple(kv_nope_fp8.shape)} dtype={kv_nope_fp8.dtype}")
            print(f" kv_nope_scale shape={tuple(kv_nope_scale.shape)} dtype={kv_nope_scale.dtype}")
            print(f" kv_rope_bf16 shape={tuple(kv_rope_bf16.shape)} dtype={kv_rope_bf16.dtype}")

        # After all prefill tokens, check KV state
        print(f" L{li} after prefill: n_comp={kv_cache.n_comp} swa_len={kv_cache.get_swa()[0].shape[0]}")

    print("\n" + "=" * 70)
    print("DONE")
    print("=" * 70)
|
|
|
|
if __name__ == '__main__':
    # Run the diagnostic only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|