Files
nvfp4-megamoe-kernel/single_shot_inference.py

389 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Single-shot DSV4 inference — baseline kernel verification.
Runs one deterministic inference request through the production kernel
stack WITHOUT vLLM/sglang. Bare-metal test to verify kernel correctness.
Uses BF16 matmul after NVFP4 dequant for the linear layers (baseline).
The FMHA kernel runs on the production path (tcgen05 MMA, TMA, real deal).
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python3 single_shot_inference.py
"""
import os, sys, time, json, math
import torch
from pathlib import Path
# Directory that main() reads config.json, the tokenizer, and the
# safetensors checkpoint index/shards from.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Upper bound on the number of decode-loop iterations in main().
MAX_NEW_TOKENS = 8
# Prompt text that main() tokenizes to seed generation.
PROMPT = "The capital of France is"
# =====================================================================
# NVFP4 dequantization
# =====================================================================
# FP4 E2M1 magnitude table: 3-bit code (2-bit exponent with bias 1, 1-bit
# mantissa) -> value. Exponent 0 is the subnormal range (0, 0.5); the format
# has no Inf/NaN encodings, so every code maps to a finite number.
# The sign bit (bit 3 of each nibble) is applied in dequant_nvfp4_weight().
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4_weight(
    weight: torch.Tensor,          # (out, in/2) uint8
    weight_scale: torch.Tensor,    # (out, in/16) float8_e4m3fn
    weight_scale_2: torch.Tensor,  # scalar float32 — global scale
) -> torch.Tensor:
    """Dequantize an NVFP4-packed weight matrix to BF16.

    Format: 2 FP4 (E2M1) values per byte (low nibble first, high nibble
    second). Each run of 16 consecutive input elements shares one entry of
    ``weight_scale``; ``weight_scale_2`` is multiplied on top of that.

    Args:
        weight: Packed FP4 codes, shape ``(out, in/2)``.
        weight_scale: Per-16-element block scales, shape ``(out, in/16)``.
        weight_scale_2: Global scale, broadcast against ``weight_scale``.

    Returns:
        The dequantized ``(out, in)`` matrix as ``torch.bfloat16`` on
        ``weight``'s device.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    device = weight.device

    # 16-entry signed table indexed directly by the 4-bit code: codes 0-7 are
    # the positive magnitudes, codes 8-15 (sign bit set) are their negations.
    magnitudes = FP4_LUT.to(device=device, dtype=torch.float32)
    signed_lut = torch.cat([magnitudes, -magnitudes])

    # Unpack nibbles. The explicit mask on the high nibble keeps the index in
    # range even if the packed tensor arrives with a signed integer dtype.
    low_code = (weight & 0x0F).long()           # (out, in/2)
    high_code = ((weight >> 4) & 0x0F).long()   # (out, in/2)

    # Interleave: [low0, high0, low1, high1, ...]
    w_f = torch.stack(
        [signed_lut[low_code], signed_lut[high_code]], dim=-1
    ).reshape(out_dim, in_features)

    # Combine block and global scales in float32, then expand each block
    # scale across the 16 elements it covers.
    scale_f = weight_scale.to(device).float() * weight_scale_2.to(device).float()
    scale_expanded = scale_f.repeat_interleave(16, dim=1)
    return (w_f * scale_expanded).bfloat16()
# =====================================================================
# Checkpoint reader
# =====================================================================
class CheckpointReader:
    """Lazy reader for a sharded safetensors checkpoint directory.

    The tensor-name -> shard-file mapping is taken from the ``weight_map``
    entry of ``model.safetensors.index.json``. Shard files are loaded on
    first use and kept in an in-memory cache until ``clear()`` is called.
    """

    def __init__(self, d):
        self.dir = Path(d)
        self._wm = None    # tensor name -> shard file name
        self._cache = {}   # shard file name -> dict of loaded tensors
        self._build_index()

    def _build_index(self):
        """Read the weight map; use an empty map when no index file exists."""
        index_path = self.dir / "model.safetensors.index.json"
        if not index_path.exists():
            self._wm = {}
            return
        with open(index_path) as fh:
            index = json.load(fh)
        self._wm = index.get("weight_map", {})

    def _load_shard(self, name):
        """Return the tensors of shard ``name``, loading the file if needed."""
        if name not in self._cache:
            # Imported lazily so the class can be built without safetensors.
            from safetensors.torch import load_file
            self._cache[name] = load_file(str(self.dir / name))
        return self._cache[name]

    def get(self, key):
        """Return the tensor stored under ``key``, or None if it is unknown."""
        if not self._wm or key not in self._wm:
            return None
        return self._load_shard(self._wm[key]).get(key)

    def get_layer(self, idx):
        """Return every tensor whose name starts with ``model.layers.{idx}.``."""
        prefix = f"model.layers.{idx}."
        tensors = {}
        if not self._wm:
            return tensors
        # Only the shards that the index says hold this layer are opened.
        shard_names = {
            shard for key, shard in self._wm.items() if key.startswith(prefix)
        }
        for shard_name in shard_names:
            tensors.update(
                (key, value)
                for key, value in self._load_shard(shard_name).items()
                if key.startswith(prefix)
            )
        return tensors

    def clear(self):
        """Drop all cached shards and release cached CUDA allocator blocks."""
        self._cache.clear()
        torch.cuda.empty_cache()
# =====================================================================
# Linear layers
# =====================================================================
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """NVFP4 linear: dequant → BF16 matmul.

    Args:
        x: Activations; the matmul runs on ``x``'s device.
        weight: Packed NVFP4 weight, as accepted by ``dequant_nvfp4_weight``.
        weight_scale: Per-block scales for ``weight``.
        weight_scale_2: Global scale for ``weight``.

    Returns:
        ``x @ W.T`` where ``W`` is the dequantized BF16 weight.
    """
    # Callers pass tensors exactly as they come out of the checkpoint, which
    # may live on a different device than x (safetensors loads to CPU by
    # default). Move the small packed tensors first so the dequantization and
    # the matmul both run on x's device instead of failing on a device
    # mismatch. `.to()` is a no-op when the tensor is already there.
    device = x.device
    w = dequant_nvfp4_weight(
        weight.to(device),
        weight_scale.to(device),
        weight_scale_2.to(device),
    )
    return torch.nn.functional.linear(x, w)
def bf16_linear(x, weight):
    """Project ``x`` through ``weight`` after casting the weight to BF16."""
    weight_bf16 = weight.to(torch.bfloat16)
    return torch.nn.functional.linear(x, weight_bf16)
# =====================================================================
# RoPE
# =====================================================================
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
    """Precompute the RoPE cos/sin tables for positions ``0..max_pos-1``.

    ``head_dim`` is accepted for interface compatibility but is not used in
    the computation. The tables are computed in float32 on the CPU and then
    moved to ``device``.

    Returns:
        ``(cos, sin)``, each of shape ``(max_pos, len(range(0, rope_dim, 2)))``.
    """
    # One inverse frequency per rotated pair: theta ** (-2i / rope_dim).
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
    """Rotate the trailing ``rope_dim`` channels of every head.

    x: (T, n_h, hd) → new tensor of the same shape; the leading
    ``hd - rope_dim`` channels are copied through unchanged and ``x`` itself
    is not modified. Adjacent channel pairs (even, odd) are rotated by the
    angle looked up in ``cos_cache`` / ``sin_cache`` at ``positions``.
    """
    _, _, head_dim = x.shape
    rope_start = head_dim - rope_dim

    # (T, half) lookups, broadcast over the head axis and cast to x's dtype.
    cos = cos_cache[positions].unsqueeze(1).to(x.dtype)
    sin = sin_cache[positions].unsqueeze(1).to(x.dtype)

    rope_in = x[:, :, rope_start:]
    even, odd = rope_in[:, :, 0::2], rope_in[:, :, 1::2]

    # Write the rotated pairs into a view of the copy so the result keeps the
    # same memory layout a plain clone of x would have.
    out = x.clone()
    rope_out = out[:, :, rope_start:]
    rope_out[..., 0::2] = even * cos - odd * sin
    rope_out[..., 1::2] = even * sin + odd * cos
    return out
# =====================================================================
# Single layer forward
# =====================================================================
def _nvfp4_proj(x, w, prefix):
    """Apply the NVFP4 linear layer whose tensors are stored under ``prefix``."""
    return nvfp4_linear(
        x,
        w[f"{prefix}.weight"],
        w[f"{prefix}.weight_scale"],
        w[f"{prefix}.weight_scale_2"],
    )


def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
    """Forward one layer. x: (1, hidden) BF16 → (1, hidden) BF16.

    Baseline path only: low-rank Q projection, KV projection, RoPE, the
    production attention kernel, the two-stage output projection, and the
    shared-expert FFN, each followed by a residual add. No normalization is
    applied and routed experts are not evaluated.

    Args:
        x: Hidden states for the tokens being processed.
        w: Mapping of tensor name -> tensor for this layer.
        li: Layer index; selects the tensor-name prefix and the compress ratio.
        cfg: Model config dict (``num_attention_heads``, ``head_dim``,
            ``qk_rope_head_dim``, ``compress_ratios``).
        rope_cos: RoPE cosine table, as consumed by ``apply_rope``.
        rope_sin: RoPE sine table, as consumed by ``apply_rope``.

    Returns:
        Hidden states after the attention and FFN residual updates.

    Raises:
        KeyError: If a required projection tensor is missing from ``w``.
        ValueError: If the layer's compress ratio is not 0, 4, or 128.
    """
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    compress = cfg["compress_ratios"][li]  # 128=HCA, 4=CSA, 0=SWA
    pre = f"model.layers.{li}.self_attn"
    T = x.shape[0]

    # ---- Q projection: q_a (down) → q_b (up) ----
    # DSV4 uses mHC prenorm, not standard layernorm. For baseline the
    # per-projection norms (q_a_norm / kv_norm) are skipped and the raw
    # hidden state is projected directly.
    c_Q = _nvfp4_proj(x, w, f"{pre}.q_a_proj")    # (1, q_lora)
    q = _nvfp4_proj(c_Q, w, f"{pre}.q_b_proj")    # (1, n_h * hd)

    # ---- KV projection ----
    kv = _nvfp4_proj(x, w, f"{pre}.kv_proj")      # (1, kv_dim)

    # ---- Pick the key latent for this attention variant ----
    if compress == 0:      # SWA
        k_latent = kv
    elif compress == 128:  # HCA: first half of the projection
        k_latent = kv.chunk(2, dim=-1)[0]
    elif compress == 4:    # CSA
        # kv has 4 streams: Ca, Cb, Za, Zb
        # For baseline decode with no cache, just use Ca
        k_latent = kv[..., :hd]
    else:
        # Previously fell through and failed later with an unbound `k`.
        raise ValueError(
            f"layer {li}: unsupported compress ratio {compress!r} "
            "(expected 0, 4, or 128)"
        )

    # ---- Reshape for attention (token-first, as apply_rope expects) ----
    q_tok = q.reshape(T, n_h, hd)         # (T, n_h, hd)
    k_tok = k_latent.reshape(T, 1, hd)    # (T, 1, hd)
    # V shares the key latent and is taken before RoPE is applied.
    v = k_tok.permute(1, 0, 2).clone()    # (1, T, hd)

    # ---- Apply RoPE ----
    # NOTE(review): the position is fixed at 0 for every call, so all tokens
    # are rotated with the table's first row.
    pos = torch.tensor([0], dtype=torch.long, device=x.device)  # decode step position
    # RoPE is applied in (T, heads, hd) layout and only then permuted to the
    # heads-first layout handed to the attention kernel.
    q_heads = apply_rope(q_tok, pos, rope_cos, rope_sin, rd).permute(1, 0, 2)  # (n_h, T, hd)
    k = apply_rope(k_tok, pos, rope_cos, rope_sin, rd).permute(1, 0, 2)        # (1, T, hd)

    # ---- FMHA ----
    from dsv4.kernels.attention.production import dsv4_attention
    attn_out = dsv4_attention(q_heads, k, v)                    # (n_h, T, hd)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)   # (T, n_h*hd)

    # ---- Output projection: o_a (BF16 grouped) → o_b (NVFP4) ----
    # o_a is BF16 grouped linear — treat as dense for baseline. The weight is
    # moved to x's device rather than a hard-coded default CUDA device.
    oa_w = w[f"{pre}.o_a_proj.weight"]
    grouped = bf16_linear(attn_out, oa_w.to(x.device))          # (1, o_groups*o_rank)
    attn_proj = _nvfp4_proj(grouped, w, f"{pre}.o_b_proj")      # (1, H)

    # ---- Residual ----
    x = x + attn_proj

    # ---- FFN (shared expert only for baseline) ----
    # No separate FFN norm in DSV4 — mHC handles it.
    # Shared expert: gate_proj + up_proj → SiLU(gate) * up → down_proj
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    has_shared_expert = all(
        w.get(f"{se_pre}.{name}.weight") is not None
        for name in ("gate_proj", "up_proj", "down_proj")
    )
    if has_shared_expert:
        gate = _nvfp4_proj(x, w, f"{se_pre}.gate_proj")
        up = _nvfp4_proj(x, w, f"{se_pre}.up_proj")
        ffn_out = _nvfp4_proj(
            torch.nn.functional.silu(gate) * up, w, f"{se_pre}.down_proj"
        )
        x = x + ffn_out
        # Note: for full model, also need routed experts + scaling
    else:
        print(f" L{li}: no shared expert weights, skipping FFN")
    return x
# =====================================================================
# Main
# =====================================================================
def main():
    """Greedy-decode up to MAX_NEW_TOKENS tokens from PROMPT and print them.

    Reads the model config, tokenizer, and weights from CHECKPOINT_DIR,
    streams the layers through ``forward_layer`` one at a time, and prints
    the token id, decoded text, and wall time of every decode step.

    Raises:
        RuntimeError: If the embedding matrix is missing from the checkpoint.
    """
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Baseline Kernel Verification")
    print("=" * 70)

    # Config
    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Compress ratios (first 10): {cfg['compress_ratios'][:10]}")

    # Tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}' -> {input_ids.tolist()}")

    # RoPE cache
    rope_cos, rope_sin = build_rope_cache(8192, hd, rd, 'cuda')

    # Checkpoint
    reader = CheckpointReader(CHECKPOINT_DIR)

    # Embedding
    embed_w = reader.get("model.embed_tokens.weight")
    if embed_w is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            f"model.embed_tokens.weight not found in checkpoint at {CHECKPOINT_DIR}"
        )
    embed = torch.nn.Embedding.from_pretrained(embed_w.cuda().bfloat16())

    # lm_head (often tied with embedding)
    lm_w = reader.get("lm_head.weight")
    if lm_w is None:
        # Tied weights: reuse the BF16 copy already on the GPU rather than
        # uploading the embedding matrix a second time.
        lm_head_w = embed.weight
        print("lm_head tied with embedding")
    else:
        lm_head_w = lm_w.cuda().bfloat16()

    # Final norm: upload once, outside the decode loop.
    final_norm_w = reader.get("model.norm.weight")
    final_norm_gpu = None if final_norm_w is None else final_norm_w.cuda().float()

    # ---- Decode loop ----
    # NOTE(review): each step feeds only the most recent token through the
    # layers; earlier tokens are not passed to forward_layer.
    print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
    generated = input_ids[0].tolist()
    for step in range(MAX_NEW_TOKENS):
        t0 = time.perf_counter()
        tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda')

        # Embed
        x = embed(tid)  # (1, H)

        # Layers (streaming — load one at a time)
        for li in range(n_layers):
            lw = reader.get_layer(li)
            if not lw:
                print(f" L{li}: no weights!")
                continue
            x = forward_layer(x, lw, li, cfg, rope_cos, rope_sin)
            del lw
            # Drop cached shards every 10 layers to bound host memory.
            if li % 10 == 9:
                reader.clear()

        # Final norm (RMSNorm computed in float32, eps=1e-6)
        if final_norm_gpu is not None:
            xf = x.float()
            rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
            x = (xf * rms * final_norm_gpu).bfloat16()

        # lm_head
        logits = torch.nn.functional.linear(x, lm_head_w)
        next_id = torch.argmax(logits, dim=-1).item()
        generated.append(next_id)
        tok_str = tokenizer.decode([next_id])
        dt = time.perf_counter() - t0
        print(f" Step {step}: {next_id} '{tok_str}' ({dt:.1f}s)")
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()