389 lines
13 KiB
Python
389 lines
13 KiB
Python
#!/usr/bin/env python3
|
||
"""Single-shot DSV4 inference — baseline kernel verification.
|
||
|
||
Runs one deterministic inference request through the production kernel
|
||
stack WITHOUT vLLM/sglang. Bare-metal test to verify kernel correctness.
|
||
|
||
Uses BF16 matmul after NVFP4 dequant for the linear layers (baseline).
|
||
The FMHA kernel runs on the production path (tcgen05 MMA, TMA, real deal).
|
||
|
||
Usage (on B200):
|
||
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||
cd /root/dsv4-nvfp4-workspace/kernel
|
||
python3 single_shot_inference.py
|
||
"""
|
||
import os, sys, time, json, math
|
||
import torch
|
||
from pathlib import Path
|
||
|
||
# Run settings. Each may be overridden from the environment so the script can
# be pointed at another checkpoint / prompt without editing the file; the
# defaults are the original hard-coded values.
CHECKPOINT_DIR = os.environ.get(
    "DSV4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)

MAX_NEW_TOKENS = int(os.environ.get("DSV4_MAX_NEW_TOKENS", "8"))

PROMPT = os.environ.get("DSV4_PROMPT", "The capital of France is")
|
||
|
||
|
||
# =====================================================================
|
||
# NVFP4 dequantization
|
||
# =====================================================================
|
||
|
||
# FP4 E2M1 lookup table: 3-bit magnitude code -> float value (unsigned).
# E2M1 = 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit. Exponent 00 is
# subnormal (0.0, 0.5); the format has no Inf/NaN encoding, so the largest
# magnitude is 6.0. Magnitude codes 0..7 decode to:
#   0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

# Number of consecutive FP4 elements (along the input dim) sharing one block scale.
_NVFP4_BLOCK = 16


def dequant_nvfp4_weight(
    weight: torch.Tensor,  # (out, in/2) uint8
    weight_scale: torch.Tensor,  # (out, in/16) float8_e4m3fn
    weight_scale_2: torch.Tensor,  # scalar float32 — global scale
) -> torch.Tensor:
    """Dequantize an NVFP4-packed weight matrix to BF16.

    Layout: each byte packs two FP4 (E2M1) codes. The low nibble is the
    even-indexed element and the high nibble is the following odd one. In
    each 4-bit code, bit 3 is the sign and bits 0-2 index ``FP4_LUT``. Every
    run of 16 consecutive elements along the input dimension shares one
    scale from ``weight_scale``. ``weight_scale_2`` is a global scale
    multiplied on top:

        W = fp4_value * weight_scale[block] * weight_scale_2

    Args:
        weight: ``(out, in/2)`` packed codes, uint8. int8 storage is
            reinterpreted bit-for-bit as uint8.
        weight_scale: ``(out, in/16)`` per-block scales (any float dtype;
            converted with ``.float()``).
        weight_scale_2: scalar global scale tensor.

    Returns:
        ``(out, in)`` BF16 tensor on ``weight``'s device.

    Raises:
        ValueError: if ``weight`` is not 2-D, ``in`` is not a multiple of 16,
            or ``weight_scale`` is not 2-D with ``in / 16`` columns.
    """
    if weight.dim() != 2:
        raise ValueError(f"expected 2-D packed weight, got shape {tuple(weight.shape)}")
    if weight.dtype == torch.int8:
        # Reinterpret the bits so `>> 4` is a logical, not arithmetic, shift.
        weight = weight.view(torch.uint8)

    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    if in_features % _NVFP4_BLOCK:
        raise ValueError(
            f"in_features={in_features} is not a multiple of {_NVFP4_BLOCK}"
        )
    n_blocks = in_features // _NVFP4_BLOCK
    if weight_scale.dim() != 2 or weight_scale.shape[1] != n_blocks:
        raise ValueError(
            f"weight_scale shape {tuple(weight_scale.shape)} does not match "
            f"(out, {n_blocks}) for in_features={in_features}"
        )

    # 16-entry signed table indexed directly by the 4-bit code:
    # codes 0-7 -> +LUT, codes 8-15 (sign bit set) -> -LUT.
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])

    low = (weight & 0x0F).long()  # even elements
    high = (weight >> 4).long()  # odd elements
    # Interleave to [low0, high0, low1, high1, ...] along the input dim.
    codes = torch.stack([low, high], dim=-1).reshape(out_dim, in_features)
    w_f = signed_lut[codes]

    # Broadcast each block scale over its 16 elements (no repeat_interleave copy).
    scale_f = weight_scale.float() * weight_scale_2.float()  # (out, n_blocks)
    w_f = w_f.view(out_dim, n_blocks, _NVFP4_BLOCK) * scale_f.unsqueeze(-1)
    return w_f.reshape(out_dim, in_features).bfloat16()
|
||
|
||
|
||
# =====================================================================
|
||
# Checkpoint reader
|
||
# =====================================================================
|
||
|
||
class CheckpointReader:
    """Lazy, shard-caching reader for a HuggingFace safetensors checkpoint.

    Tensor names are resolved through ``model.safetensors.index.json`` (the
    sharded layout). If there is no index but a single ``model.safetensors``
    file exists, all of its tensor names map to that file. Shards are loaded
    whole with ``safetensors.torch.load_file`` and cached in ``self._cache``
    until :meth:`clear` is called.
    """

    _INDEX_FILE = "model.safetensors.index.json"
    _SINGLE_FILE = "model.safetensors"

    def __init__(self, d):
        self.dir = Path(d)
        self._wm = None  # tensor name -> shard file name (relative to self.dir)
        self._cache = {}  # shard file name -> {tensor name: tensor}
        self._build_index()

    def _build_index(self):
        """Populate ``self._wm`` from the index file or a single-file checkpoint."""
        index_path = self.dir / self._INDEX_FILE
        if index_path.exists():
            with open(index_path, encoding="utf-8") as f:
                self._wm = json.load(f).get("weight_map", {})
            return

        single_path = self.dir / self._SINGLE_FILE
        if single_path.exists():
            # Unsharded checkpoint: list tensor names without loading the data.
            from safetensors import safe_open

            with safe_open(str(single_path), framework="pt") as f:
                self._wm = {k: self._SINGLE_FILE for k in f.keys()}
            return

        self._wm = {}

    def _load_shard(self, name):
        """Return ``{tensor name: tensor}`` for shard ``name``, loading it on first use."""
        if name in self._cache:
            return self._cache[name]
        from safetensors.torch import load_file

        data = load_file(str(self.dir / name))
        self._cache[name] = data
        return data

    def get(self, key):
        """Return tensor ``key``, or ``None`` if the checkpoint does not contain it."""
        shard_name = self._wm.get(key) if self._wm else None
        if shard_name is None:
            return None
        return self._load_shard(shard_name).get(key)

    def get_layer(self, idx):
        """Return ``{name: tensor}`` for every tensor under ``model.layers.{idx}.``.

        The trailing dot in the prefix keeps layer 1 from matching layer 10.
        Returns an empty dict when nothing matches.
        """
        pre = f"model.layers.{idx}."
        out = {}
        if not self._wm:
            return out
        shards = {s for k, s in self._wm.items() if k.startswith(pre)}
        for s in shards:
            for k, v in self._load_shard(s).items():
                if k.startswith(pre):
                    out[k] = v
        return out

    def clear(self):
        """Drop all cached shards and release cached CUDA memory."""
        self._cache.clear()
        torch.cuda.empty_cache()
|
||
|
||
|
||
# =====================================================================
|
||
# Linear layers
|
||
# =====================================================================
|
||
|
||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """NVFP4 linear layer: dequantize the weight to BF16, then ``F.linear(x, W)``.

    The packed weight and both scales are first moved to ``x``'s device.
    ``safetensors.torch.load_file`` returns CPU tensors. Without this step the
    dequant would run on the CPU and the matmul against a CUDA activation
    would fail with a device mismatch. Moving the packed uint8 weight is also
    4x cheaper than moving the dequantized BF16 one.

    Args:
        x: activations, last dim = ``in`` features.
        weight, weight_scale, weight_scale_2: NVFP4 tensors as accepted by
            ``dequant_nvfp4_weight``.

    Returns:
        ``x @ W.T`` with ``W`` in BF16.
    """
    dev = x.device
    w = dequant_nvfp4_weight(
        weight.to(dev), weight_scale.to(dev), weight_scale_2.to(dev)
    )
    return torch.nn.functional.linear(x, w)
|
||
|
||
def bf16_linear(x, weight):
    """BF16 linear layer: ``x @ weight.T`` with ``weight`` cast to BF16.

    ``weight`` is moved to ``x``'s device, because checkpoint tensors are
    loaded on the CPU. Callers therefore do not need to place it themselves.
    """
    w = weight.to(device=x.device, dtype=torch.bfloat16)
    return torch.nn.functional.linear(x, w)
|
||
|
||
|
||
# =====================================================================
|
||
# RoPE
|
||
# =====================================================================
|
||
|
||
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
    """Precompute cos/sin tables for interleaved (GPT-J style) partial RoPE.

    Args:
        max_pos: number of positions tabulated (rows of each table).
        head_dim: full head dimension; accepted for API symmetry, unused here.
        rope_dim: number of rotated dims; one frequency per even index below it.
        device: device the returned tables are moved to.
        theta: RoPE frequency base.

    Returns:
        ``(cos, sin)`` float32 tables of shape ``(max_pos, rope_dim // 2)``
        for even ``rope_dim``, where entry ``[p, i]`` uses angle
        ``p * theta ** (-2i / rope_dim)``.
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)

    # Outer product: angle[p, i] = pos[p] * inv_freq[i]
    angle = pos[:, None] * inv_freq[None, :]

    cos_table = torch.cos(angle).to(device)
    sin_table = torch.sin(angle).to(device)
    return cos_table, sin_table
|
||
|
||
|
||
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
    """Apply interleaved (GPT-J style) partial RoPE to the last ``rope_dim`` dims.

    The leading ``hd - rope_dim`` dims of each head pass through unchanged.
    In the trailing ``rope_dim`` dims, each adjacent pair ``(x[2i], x[2i+1])``
    is rotated by the angle for its token's position and frequency ``i``:

        x'[2i]   = x[2i] * cos - x[2i+1] * sin
        x'[2i+1] = x[2i] * sin + x[2i+1] * cos

    The rotation is computed in float32 and cast back to ``x.dtype``, as the
    DeepSeek reference ``apply_rotary_emb`` does. This avoids the extra
    rounding of doing the multiply-adds in BF16.

    Args:
        x: ``(T, n_h, hd)`` tensor; dim 0 must be the token axis.
        positions: ``(T,)`` (or length-1, broadcast) long tensor indexing
            the caches.
        cos_cache, sin_cache: ``(max_pos, rope_dim // 2)`` tables.
        rope_dim: even number of trailing dims to rotate, ``<= hd``.

    Returns:
        A new tensor with the same shape and dtype as ``x``; ``x`` is not
        modified.

    Raises:
        ValueError: if ``x`` is not 3-D, or ``rope_dim`` is odd or out of range.
    """
    if x.dim() != 3:
        raise ValueError(f"expected x of shape (T, n_h, hd), got {tuple(x.shape)}")
    hd = x.shape[-1]
    if rope_dim % 2 or not 0 <= rope_dim <= hd:
        raise ValueError(f"rope_dim must be even and in [0, {hd}], got {rope_dim}")
    nope = hd - rope_dim

    # (T, half) -> (T, 1, half): broadcast across heads.
    cos = cos_cache[positions].float().unsqueeze(1)
    sin = sin_cache[positions].float().unsqueeze(1)

    x_pass = x[..., :nope]
    x_rot = x[..., nope:].float()
    even = x_rot[..., 0::2]
    odd = x_rot[..., 1::2]

    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos

    # Re-interleave to [e0, o0, e1, o1, ...] and reattach the pass-through dims.
    rotated = torch.stack([rot_even, rot_odd], dim=-1).flatten(-2)
    return torch.cat([x_pass, rotated.to(x.dtype)], dim=-1)
|
||
|
||
|
||
# =====================================================================
|
||
# Single layer forward
|
||
# =====================================================================
|
||
|
||
def forward_layer(x, w, li, cfg, rope_cos, rope_sin, positions=None):
    """Forward one decoder layer (baseline): attention + shared-expert FFN.

    ``x = x + attention(x)``, then ``x = x + shared_expert(x)`` when the
    shared-expert weights are present (otherwise a message is printed and the
    FFN is skipped). No norms are applied: per this baseline's design, DSV4's
    mHC prenorm is not modelled. Routed experts are also not applied.

    Args:
        x: ``(T, hidden)`` BF16 activations.
        w: ``{tensor name: tensor}`` for this layer (``CheckpointReader.get_layer``).
        li: layer index; selects ``cfg["compress_ratios"][li]`` and weight prefixes.
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables from ``build_rope_cache``.
        positions: optional ``(T,)`` long positions; defaults to ``arange(T)``
            (i.e. ``[0]`` for the single-token decode used by ``main``).

    Returns:
        ``(T, hidden)`` BF16 activations.
    """
    T = x.shape[0]
    if positions is None:
        positions = torch.arange(T, dtype=torch.long, device=x.device)

    x = x + _attention_block(x, w, li, cfg, rope_cos, rope_sin, positions)

    ffn_out = _shared_expert_block(x, w, li)
    if ffn_out is None:
        print(f" L{li}: no shared expert weights, skipping FFN")
        return x
    # NOTE: the full model also needs the routed experts + their scaling.
    return x + ffn_out


def _nvfp4_proj(x, w, prefix):
    """Apply the NVFP4 linear stored as ``{prefix}.weight`` / ``.weight_scale`` / ``.weight_scale_2``.

    Tensors are moved to ``x``'s device first (checkpoint tensors load on the CPU).
    """
    dev = x.device
    return nvfp4_linear(
        x,
        w[f"{prefix}.weight"].to(dev),
        w[f"{prefix}.weight_scale"].to(dev),
        w[f"{prefix}.weight_scale_2"].to(dev),
    )


def _select_kv_stream(kv, compress, hd):
    """Return the ``hd``-wide K/V stream of the fused KV projection for this layer type."""
    if compress == 0:  # SWA: the projection is a single stream
        return kv
    if compress == 128:  # HCA: [C | Z]; baseline uses C
        return kv.chunk(2, dim=-1)[0]
    if compress == 4:  # CSA: 4 streams (Ca, Cb, Za, Zb); baseline decode with no cache uses Ca
        return kv[..., :hd]
    raise ValueError(
        f"unsupported compress ratio {compress!r} (expected 0=SWA, 4=CSA or 128=HCA)"
    )


def _attention_block(x, w, li, cfg, rope_cos, rope_sin, positions):
    """Attention sub-layer of layer ``li``; returns the ``(T, hidden)`` output projection."""
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    compress = cfg["compress_ratios"][li]  # 128=HCA, 4=CSA, 0=SWA
    pre = f"model.layers.{li}.self_attn"
    T = x.shape[0]

    # NOTE(review): this baseline applies no norm here. It does not use the
    # `{pre}.q_a_norm.weight` / `{pre}.kv_norm.weight` tensors (if present),
    # and it applies no hidden-level norm (mHC is expected to handle it).
    # TODO confirm whether the per-projection norms belong in the baseline.

    # Q: low-rank down-projection (q_a) then up-projection (q_b) to all heads.
    c_q = _nvfp4_proj(x, w, f"{pre}.q_a_proj")
    q = _nvfp4_proj(c_q, w, f"{pre}.q_b_proj")  # (T, n_h * hd)

    # KV: fused projection; pick the stream this layer type uses.
    kv = _nvfp4_proj(x, w, f"{pre}.kv_proj")
    kv_heads = _select_kv_stream(kv, compress, hd).reshape(T, 1, hd)  # (T, 1, hd)

    # V is the un-rotated K stream (taken before RoPE, as in the baseline).
    v = kv_heads.permute(1, 0, 2).clone()  # (1, T, hd)

    # RoPE works on (T, heads, hd): rotate first, then move heads to the front.
    q_heads = apply_rope(q.reshape(T, n_h, hd), positions, rope_cos, rope_sin, rd)
    q_heads = q_heads.permute(1, 0, 2).contiguous()  # (n_h, T, hd)
    k = apply_rope(kv_heads, positions, rope_cos, rope_sin, rd)
    k = k.permute(1, 0, 2).contiguous()  # (1, T, hd)

    # FMHA on the production kernel path.
    from dsv4.kernels.attention.production import dsv4_attention

    attn_out = dsv4_attention(q_heads, k, v)  # (n_h, T, hd)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)

    # Output projection: o_a is a BF16 grouped linear, treated as dense for
    # this baseline, followed by the NVFP4 o_b.
    oa_w = w[f"{pre}.o_a_proj.weight"].to(x.device)
    grouped = bf16_linear(attn_out, oa_w)
    return _nvfp4_proj(grouped, w, f"{pre}.o_b_proj")


def _shared_expert_block(x, w, li):
    """Shared-expert SwiGLU FFN ``down(silu(gate(x)) * up(x))``; ``None`` if its weights are missing."""
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    if any(
        w.get(f"{se_pre}.{name}.weight") is None
        for name in ("gate_proj", "up_proj", "down_proj")
    ):
        return None
    gate = _nvfp4_proj(x, w, f"{se_pre}.gate_proj")
    up = _nvfp4_proj(x, w, f"{se_pre}.up_proj")
    return _nvfp4_proj(torch.nn.functional.silu(gate) * up, w, f"{se_pre}.down_proj")
|
||
|
||
|
||
# =====================================================================
|
||
# Main
|
||
# =====================================================================
|
||
|
||
def main():
    """Greedy-decode up to ``MAX_NEW_TOKENS`` tokens after ``PROMPT`` and print them.

    Reads ``config.json`` and the tokenizer from ``CHECKPOINT_DIR``. On every
    step it streams all decoder layers from disk (``CheckpointReader.get_layer``),
    clearing the shard cache every 10 layers. It prints each generated token
    with its wall time, then the full decoded text.

    NOTE(review): each step embeds only the most recent token, and there is
    no KV cache, so the layers never see the earlier prompt tokens. Treat the
    output as a kernel smoke test, not a faithful completion of ``PROMPT``.
    """
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Baseline Kernel Verification")
    print("=" * 70)

    # Config
    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    # Use the config's values when present; fall back to the previous constants.
    rope_theta = cfg.get("rope_theta", 10000.0)
    norm_eps = cfg.get("rms_norm_eps", 1e-6)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Compress ratios (first 10): {cfg['compress_ratios'][:10]}")

    # Tokenizer
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}' → {input_ids.tolist()}")

    # RoPE cache
    rope_cos, rope_sin = build_rope_cache(8192, hd, rd, 'cuda', theta=rope_theta)

    # Checkpoint
    reader = CheckpointReader(CHECKPOINT_DIR)

    # Embedding
    embed_w = reader.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError(f"model.embed_tokens.weight not found in {CHECKPOINT_DIR}")
    embed = torch.nn.Embedding.from_pretrained(embed_w.cuda().bfloat16())

    # lm_head (falls back to the embedding matrix when absent)
    lm_w = reader.get("lm_head.weight")
    if lm_w is None:
        lm_w = embed_w
        print("lm_head tied with embedding")
    lm_head_w = lm_w.cuda().bfloat16()

    # Final norm weight: moved to the GPU once, not on every step.
    final_norm_w = reader.get("model.norm.weight")
    final_norm = None if final_norm_w is None else final_norm_w.cuda().float()

    # ---- Decode loop ----
    print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
    generated = input_ids[0].tolist()

    with torch.inference_mode():
        for step in range(MAX_NEW_TOKENS):
            t0 = time.perf_counter()
            tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda')

            x = embed(tid)  # (1, H)

            # Layers (streaming — load one at a time)
            for li in range(n_layers):
                lw = reader.get_layer(li)
                if not lw:
                    print(f" L{li}: no weights!")
                    continue
                x = forward_layer(x, lw, li, cfg, rope_cos, rope_sin)
                del lw
                if li % 10 == 9:
                    reader.clear()

            # Final RMSNorm (computed in float32)
            if final_norm is not None:
                xf = x.float()
                rms = xf.pow(2).mean(-1, keepdim=True).add(norm_eps).rsqrt()
                x = (xf * rms * final_norm).bfloat16()

            # lm_head + greedy pick
            logits = torch.nn.functional.linear(x, lm_head_w)
            next_id = torch.argmax(logits, dim=-1).item()  # .item() syncs the GPU
            generated.append(next_id)

            tok_str = tokenizer.decode([next_id])
            dt = time.perf_counter() - t0
            print(f" Step {step}: {next_id} '{tok_str}' ({dt:.1f}s)")

            if next_id == tokenizer.eos_token_id:
                break

    out = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()
|