- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc. - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda) - Moved PyTorch bridges to dsv4/ops/ - Moved nn.Module layers to dsv4/layers/ - Moved reference implementations to dsv4/reference/ - Moved vendored CUTLASS code to vendored/ - Archived ~190 debug tests to tests/archive/ - Kept ~15 canonical tests in tests/unit/ - Updated all import paths - Added stubs for future components (model/, cache/, loader/) - Updated pyproject.toml: dsv4-inference package name
288 lines
12 KiB
Python
288 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Full DeepSeek-V4 attention pipeline test with real weights.
|
|
|
|
Architecture (NOT MLA — CSA/HCA):
|
|
1. q_a_proj (7168→1536) + kv_proj (7168→512) — NVFP4 CuTeDSL
|
|
2. q_norm + kv_norm — RMS
|
|
3. q_b_proj (1536→65536) — NVFP4 CuTeDSL
|
|
4. RoPE on Q (GPT-J, 64 dims)
|
|
5. SWA attention (sliding window=128, causal, SDPA) — BF16
|
|
6. o_a: inverse RoPE + BMM with (16, 1024, 8192) — BF16
|
|
7. o_b: (T, 16384→7168) — NVFP4 CuTeDSL
|
|
|
|
For CSA/HCA layers, step 5 would be sparse attention with indexed positions.
|
|
This test runs three layers: SWA-only (layer 60, compress_ratio=0), C128A (layer 0)
|
|
and C4A (layer 2). All three currently go through the same dense causal attention path.
|
|
|
|
Usage (on B200):
|
|
cd /root/nvfp4-megamoe-kernel
|
|
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_v4_attention_b200.py
|
|
"""
|
|
|
|
import sys, os, json, torch, torch.nn.functional as F, math
|
|
from safetensors import safe_open
|
|
|
|
# Repository checkout; prepended to sys.path (presumably so the in-repo
# ``dsv4`` package imported by make_runner resolves — confirm on the target box).
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)

# Directory holding the safetensors shards and their index file.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device every loaded tensor is moved to.
DEV = "cuda:0"

# Model config
H = 7168        # input width of q_a_proj / kv_proj
NH = 128        # attention heads
HD = 512        # per-head width (NOPE + ROPE)
NOPE = 448      # leading per-head dims left untouched by RoPE
ROPE = 64       # trailing per-head dims rotated by RoPE
QL = 1536       # input width of q_b_proj
OL = 1024       # per-group output width of o_a
OG = 16         # number of o_a groups
HPG = NH // OG  # heads per o_a group
EPS = 1e-6      # RMS-norm epsilon
WINDOW = 128    # sliding-window length
SCALE = HD ** -0.5  # attention score scale

# Lookup table from a 4-bit code to its float value.
# NOTE: the literal ``-0`` is the integer 0, so code 8 maps to +0.0.
E2M1 = torch.tensor(
    [
        0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)
|
|
|
|
# Tensors already read from disk, keyed by tensor name only.
_cache = {}


def P(k, wm, md):
    """Return checkpoint tensor ``k``, reading it from disk at most once.

    Args:
        k: tensor name; also the cache key.
        wm: mapping from tensor name to the shard file that contains it.
        md: directory the shard files live in.

    Returns:
        The tensor as produced by ``safe_open(...).get_tensor(k)``. Later
        calls with the same ``k`` return the cached object regardless of
        ``wm`` / ``md``.
    """
    try:
        return _cache[k]
    except KeyError:
        pass

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
|
|
|
|
def dequant(w, sf, gs):
    """Expand packed 4-bit codes into a bfloat16 matrix.

    Each element of ``w`` carries two 4-bit codes: the low nibble becomes the
    even output column and the high nibble the odd one. Codes are mapped
    through the ``E2M1`` table, then multiplied by the per-16-column scale
    taken from ``sf`` and by the scalar ``gs``.

    Args:
        w: 2-D tensor of packed codes, shape (rows, cols / 2); must support
            bitwise ops (presumably uint8 — confirm against the checkpoint).
        sf: per-block scales; after a 16x repeat along dim 1 it is cropped to
            (rows, cols).
        gs: global scale multiplier.

    Returns:
        bfloat16 tensor of shape (rows, cols) on ``w``'s device.
    """
    device = w.device
    table = E2M1.to(device)
    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    # Stacking on a new last axis and flattening interleaves low/high:
    # [lo0, hi0, lo1, hi1, ...] along each row.
    values = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, cols)
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
|
|
|
|
def rms(x, w, eps=1e-6):
    """Apply RMS normalisation over the last dimension and scale by ``w``.

    Args:
        x: input tensor; statistics are computed in float32.
        w: scale broadcast against the normalised values.
        eps: added to the mean square before the reciprocal square root.

    Returns:
        Tensor with ``x``'s shape, cast back to ``x``'s dtype.
    """
    mean_square = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalised = (x * inv_rms).float()
    scaled = w.float() * normalised
    return scaled.to(x.dtype)
|
|
|
|
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised ``Nvfp4Linear`` from checkpoint NVFP4 tensors.

    Args:
        w: packed FP4 weight; reinterpreted as ``float4_e2m1fn_x2`` and
            transposed before being handed to the runner.
        sf: per-block scales, converted to ``float8_e4m3fn`` if necessary and
            transposed. Assumed to be laid out (out_features, in_features/16),
            which is how ``dequant`` in this file consumes the same tensor.
        gs_t: global scale tensor with one element, or two for a fused
            projection.
        inf: ``in_features`` passed to ``Nvfp4Linear``.
        outf: ``out_features`` passed to ``Nvfp4Linear``.
        fused: when true and ``gs_t`` has exactly two elements, the two
            halves are brought onto one global scale (the larger) by folding
            the ratio into the block scales.
        lw: optional sequence whose first item is the output-feature index
            where the second fused half starts; defaults to ``outf // 2``.

    Returns:
        The runner, after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            s32 = s.float()
            sp = lw[0] if lw else outf // 2
            # ``s`` was transposed above, so output features now run along
            # dim 1. The split point is an output-feature index and must be
            # applied there; slicing dim 0 would rescale input blocks instead.
            s32[:, :sp] *= g1 / gs
            s32[:, sp:] *= g2 / gs
            # Rounding back to FP8 can lose precision in the rescaled half.
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = Nvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
|
|
|
|
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Rotate the trailing features of ``x`` as interleaved (even, odd) pairs.

    Features before index ``nope`` on the last axis are copied unchanged.
    From ``nope`` onward, consecutive pairs ``(e, o)`` become
    ``(e*cos - o*sin, e*sin + o*cos)``.

    Args:
        x: input tensor; a 3-D input gets cos/sin broadcast over dim 1.
        positions: row indices into ``cos_sin``.
        cos_sin: table whose first ``rope // 2`` columns are cosines and whose
            next ``rope // 2`` columns are sines.
        nope: index on the last axis where the rotated part begins.
        rope: number of rotated features; only ``rope // 2`` is used.

    Returns:
        A new tensor, or ``x`` itself when ``rope`` is 0 or ``x`` is empty.
    """
    if rope == 0 or x.numel() == 0:
        return x

    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    # Read from x, which is never written; write into a clone.
    rotary = x[..., nope:]
    first, second = rotary[..., 0::2], rotary[..., 1::2]
    new_first = first * cos - second * sin
    new_second = first * sin + second * cos

    result = x.clone()
    target = result[..., nope:]  # view into ``result``
    target[..., 0::2] = new_first
    target[..., 1::2] = new_second
    return result
|
|
|
|
def apply_inv_gptj_rope(x, positions, cos_sin, nope, rope):
    """Undo the interleaved-pair rotation on the trailing features of ``x``.

    Same as the forward rotation with the sine negated: from index ``nope``
    onward, consecutive pairs ``(e, o)`` become
    ``(e*cos + o*sin, -e*sin + o*cos)``. Earlier features are copied as is.

    Args:
        x: input tensor; a 3-D input gets cos/sin broadcast over dim 1.
        positions: row indices into ``cos_sin``.
        cos_sin: table whose first ``rope // 2`` columns are cosines and whose
            next ``rope // 2`` columns are sines.
        nope: index on the last axis where the rotated part begins.
        rope: number of rotated features; only ``rope // 2`` is used.

    Returns:
        A new tensor, or ``x`` itself when ``rope`` is 0 or ``x`` is empty.
    """
    if rope == 0 or x.numel() == 0:
        return x

    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    # Read from x, which is never written; write into a clone.
    rotary = x[..., nope:]
    first, second = rotary[..., 0::2], rotary[..., 1::2]
    new_first = first * cos + second * sin
    new_second = -first * sin + second * cos

    result = x.clone()
    target = result[..., nope:]  # view into ``result``
    target[..., 0::2] = new_first
    target[..., 1::2] = new_second
    return result
|
|
|
|
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Precompute the rotary cos/sin table.

    Args:
        max_pos: number of positions (rows) to generate, starting at 0.
        rope_dim: rotary width; the table has ``rope_dim // 2`` frequencies.

    Returns:
        float32 tensor of shape (max_pos, 2 * (rope_dim // 2)): cosines in the
        first half of the columns, sines in the second half.
    """
    half = rope_dim // 2
    exponents = torch.arange(0, half, dtype=torch.float32) / half
    inv_freq = 1.0 / (10000.0 ** exponents)
    position_ids = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(position_ids, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
|
|
|
|
|
|
def swa_attention(q, kv, scale, window_size=WINDOW):
    """Causal sliding-window attention over a KV latent shared by all heads.

    This is a plain matmul + softmax reference, not a fused SDPA kernel.
    The same ``kv`` tensor serves as both keys and values. Query ``t``
    attends to keys ``max(0, t - window_size + 1) .. t``.

    Args:
        q: queries of shape (T, NH, HD).
        kv: shared latent of shape (T, HD).
        scale: multiplier applied to the raw scores.
        window_size: number of most recent keys each query may see.

    Returns:
        Tensor of shape (T, NH, HD).

    Raises:
        ValueError: if ``window_size`` is smaller than 1. Every key would be
            masked and the softmax would return NaN.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    T, n_heads, head_dim = q.shape
    if T <= window_size:
        # The window covers the whole sequence, so only the causal mask applies.
        return full_causal_attention(q, kv, scale)

    # (T, NH, HD) @ (HD, T) -> (T, NH, T). kv broadcasts over heads, so no
    # per-head / per-query copy of it is materialised.
    scores = torch.matmul(q, kv.transpose(0, 1)) * scale

    pos = torch.arange(T, device=q.device)
    query_pos = pos.unsqueeze(1)  # (T, 1)
    key_pos = pos.unsqueeze(0)    # (1, T)
    causal = key_pos <= query_pos
    in_window = key_pos >= (query_pos - window_size + 1)
    visible = causal & in_window  # (T, T), indexed [query, key]

    scores = scores.masked_fill(~visible.unsqueeze(1), float("-inf"))
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    out = torch.matmul(weights, kv)  # (T, NH, T) @ (T, HD)
    return out.reshape(T, n_heads, head_dim)
|
|
|
|
|
|
def full_causal_attention(q, kv, scale):
    """Dense causal attention over a KV latent shared by all heads.

    The same ``kv`` tensor serves as both keys and values. Query ``t``
    attends to keys ``0 .. t``. Scores are computed in ``q``'s dtype, the
    softmax in float32, and the weights are cast back to ``q``'s dtype.

    Args:
        q: queries of shape (T, NH, HD).
        kv: shared latent of shape (T, HD).
        scale: multiplier applied to the raw scores.

    Returns:
        Tensor of shape (T, NH, HD).
    """
    T, n_heads, head_dim = q.shape

    # (T, NH, HD) @ (HD, T) -> (T, NH, T). kv broadcasts over heads, so no
    # per-head / per-query copy of it is materialised.
    scores = torch.matmul(q, kv.transpose(0, 1)) * scale

    pos = torch.arange(T, device=q.device)
    causal = pos.unsqueeze(0) <= pos.unsqueeze(1)  # (T, T), indexed [query, key]

    scores = scores.masked_fill(~causal.unsqueeze(1), float("-inf"))
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    out = torch.matmul(weights, kv)  # (T, NH, T) @ (T, HD)
    return out.reshape(T, n_heads, head_dim)
|
|
|
|
|
|
def test_layer(layer_id, compress_ratio):
    """Compare the NVFP4 runner path with a dequantised BF16 path for one layer.

    Loads the layer's attention weights, pushes six fixed token ids through
    both paths (projections, RMS norms, RoPE, dense causal attention, inverse
    RoPE, grouped o_a, o_b), prints the cosine similarity of the two outputs,
    then prints a logit sanity check using the runner path's output.

    NOTE(review): ``compress_ratio`` only selects the printed label; every
    layer goes through ``full_causal_attention`` here.
    NOTE(review): RoPE is applied to Q only and the KV latent is used
    un-rotated — confirm this matches the reference model.

    Args:
        layer_id: index used to build the ``model.layers.<id>`` key prefix.
        compress_ratio: values <= 1 are labelled "SWA", anything else "CSA".

    Returns:
        Cosine similarity (float) between the runner and BF16 outputs.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()

    index_path = os.path.join(MODEL, "model.safetensors.index.json")
    with open(index_path) as index_file:
        weight_map = json.load(index_file)["weight_map"]

    def fetch(key):
        # Read (or reuse) one checkpoint tensor and move it to the test device.
        return P(key, weight_map, MODEL).to(DEV)

    def fetch_quantized(prefix):
        # Packed weight, per-block scale and global scale of one projection.
        return (
            fetch(f"{prefix}.weight"),
            fetch(f"{prefix}.weight_scale"),
            fetch(f"{prefix}.weight_scale_2"),
        )

    layer_prefix = f"model.layers.{layer_id}"
    attn_prefix = f"{layer_prefix}.self_attn"
    if compress_ratio <= 1:
        layer_type = "SWA"
    else:
        layer_type = f"CSA(c={compress_ratio})"

    banner = "=" * 70
    print(f"\n{banner}")
    print(f" Layer {layer_id} — {layer_type}")
    print(f"{banner}")

    # Load weights
    emb = fetch("model.embed_tokens.weight")
    input_norm_w = fetch(f"{layer_prefix}.input_layernorm.weight")
    q_norm_w = fetch(f"{attn_prefix}.q_a_norm.weight")
    kv_norm_w = fetch(f"{attn_prefix}.kv_norm.weight")
    # The original note gave this shape as (16384, 8192); the view further
    # down needs OG*OL x HPG*HD elements. NOTE(review): confirm which is right.
    woa = fetch(f"{attn_prefix}.o_a_proj.weight")

    qa_w, qa_sf, qa_gs = fetch_quantized(f"{attn_prefix}.q_a_proj")
    qb_w, qb_sf, qb_gs = fetch_quantized(f"{attn_prefix}.q_b_proj")
    kv_w, kv_sf, kv_gs = fetch_quantized(f"{attn_prefix}.kv_proj")
    wob_w, wob_sf, wob_gs = fetch_quantized(f"{attn_prefix}.o_b_proj")
    sinks = fetch(f"{attn_prefix}.sinks")  # loaded but not used below

    # BF16 references
    qa_ref_w = dequant(qa_w, qa_sf, qa_gs.item())
    qb_ref_w = dequant(qb_w, qb_sf, qb_gs.item())
    kv_ref_w = dequant(kv_w, kv_sf, kv_gs.item())
    wob_ref_w = dequant(wob_w, wob_sf, wob_gs.item())

    # CuTeDSL runners
    runner_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    runner_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    runner_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
    runner_wob = make_runner(wob_w, wob_sf, wob_gs, OG * OL, wob_w.shape[0])

    # Input
    NT = 6
    token_ids = torch.tensor(
        [1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV
    )
    cos_sin = build_cos_sin(max_pos=WINDOW + 256).to(DEV)
    positions = torch.arange(NT, dtype=torch.int64, device=DEV)

    # Group-major view of o_a's weight, shared by both paths.
    woa_3d = woa.view(OG, OL, HPG * HD)

    def attend_and_group(q_heads, kv_latent):
        # RoPE on Q, dense causal attention (T=6 fits inside the window),
        # inverse RoPE, then the per-group o_a BMM.
        q_rot = apply_gptj_rope(q_heads, positions, cos_sin, NOPE, ROPE)
        attended = full_causal_attention(q_rot, kv_latent, SCALE)
        unrotated = apply_inv_gptj_rope(attended, positions, cos_sin, NOPE, ROPE)
        grouped = unrotated.view(NT, OG, HPG * HD).permute(1, 0, 2)
        projected = torch.bmm(grouped, woa_3d.transpose(1, 2))
        return projected.permute(1, 0, 2).reshape(NT, OG * OL)

    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, input_norm_w, EPS)

        # ── CuTeDSL path ─────────────────────────────────────────────
        qa_cute = runner_qa.run(normed)
        kv_cute = runner_kv.run(normed)
        qa_normed = rms(qa_cute, q_norm_w, EPS)
        kv_normed = rms(kv_cute, kv_norm_w, EPS)
        q_cute = runner_qb.run(qa_normed).view(NT, NH, HD)
        z_cute = attend_and_group(q_cute, kv_normed)
        attn_out = runner_wob.run(z_cute)

        # ── BF16 reference ───────────────────────────────────────────
        qa_ref = normed @ qa_ref_w.T
        kv_ref = normed @ kv_ref_w.T
        qa_ref_normed = rms(qa_ref, q_norm_w, EPS)
        kv_ref_normed = rms(kv_ref, kv_norm_w, EPS)
        q_ref = (qa_ref_normed @ qb_ref_w.T).view(NT, NH, HD)
        z_ref = attend_and_group(q_ref, kv_ref_normed)
        attn_bf = z_ref @ wob_ref_w.T

        # ── Compare ──────────────────────────────────────────────────
        flat_cute = attn_out.flatten().unsqueeze(0).float()
        flat_ref = attn_bf.flatten().unsqueeze(0).float()
        cosine = F.cosine_similarity(flat_cute, flat_ref).item()
        print(f" CuTeDSL vs BF16 cosine: {cosine:.6f} {'✅' if cosine>=0.95 else '❌'}")
        print(f" CuTeDSL amax: {attn_out.amax():.4f} BF16 amax: {attn_bf.amax():.4f}")

        # Full forward: attention → residual → norm → LM head
        final_norm_w = fetch("model.norm.weight")
        lm_head = fetch("lm_head.weight")
        residual = hidden + attn_out
        residual_normed = rms(residual, final_norm_w, EPS)
        logits = residual_normed @ lm_head.T
        top5 = torch.topk(logits[-1], 5)
        log_std = logits[-1].float().std().item()
        print(f" logits: amax={logits.amax():.4f} std={log_std:.4f} top5={top5.indices.tolist()}")
        print(f" logit check: {'✅' if 0.5 < log_std < 50 else '❌'} (0.5 < std < 50)")

    # Cleanup
    del runner_qa, runner_qb, runner_kv, runner_wob
    torch.cuda.empty_cache()
    return cosine
|
|
|
|
|
|
def main():
    """Run the pipeline check on one layer of each attention type.

    Prints a header, runs ``test_layer`` for layers 60, 0 and 2, and prints a
    summary line per layer.

    Returns:
        int: 0 when every layer's cosine similarity reaches the threshold,
        1 otherwise (a NaN cosine counts as a failure). Intended as the
        process exit status so a failing run is visible to callers and CI.
    """
    threshold = 0.95

    print("=" * 70)
    print(" DeepSeek-V4 CSA/HCA Attention Pipeline Test")
    print(" (NOT MLA — Compressed Sparse Attention)")
    print("=" * 70)

    # (summary label, layer index, compress_ratio)
    cases = (
        ("Layer 60 (SWA):", 60, 0),
        ("Layer 0 (C128A/HCA):", 0, 128),
        ("Layer 2 (C4A/CSA):", 2, 4),
    )
    results = [
        (label, test_layer(layer_id, compress_ratio))
        for label, layer_id, compress_ratio in cases
    ]

    print(f"\n{'='*70}")
    print(f" SUMMARY")
    for label, cosine in results:
        print(f" {label} {cosine:.6f} {'✅' if cosine >= threshold else '❌'}")
    print(f"{'='*70}")

    all_passed = all(cosine >= threshold for _, cosine in results)
    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate pass/fail as the exit status instead of always exiting 0.
    sys.exit(main())
|