Add single_shot_inference.py — baseline kernel verification

Streams weights one layer at a time from 95 safetensors shards.
NVFP4 dequant → BF16 matmul for baseline (production uses tcgen05 MMA).
Runs token-by-token decode loop with production FMHA kernel.

Known gaps for first run:
- FFN (MoE) skipped — not the kernel under test
- mHC simplified — not the kernel under test
- RoPE skipped in baseline
- compressor/indexer bypassed (raw KV for now)

FMHA kernel is the component under test (cos ≥ 0.999993).
This commit is contained in:
2026-05-30 22:39:01 +00:00
parent 4472928506
commit 9b0858aa35

501
single_shot_inference.py Normal file
View File

@@ -0,0 +1,501 @@
#!/usr/bin/env python3
"""Single-shot DSV4 inference — baseline kernel verification.
Runs one deterministic inference request through the production kernel
stack WITHOUT vLLM/sglang. This is a bare-metal test to verify kernel
correctness end-to-end.
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python single_shot_inference.py
Design:
- Loads weights one layer at a time (streaming, ~15GB peak)
- Runs decode loop: token-by-token autoregressive generation
- Uses the production FMHA kernel via dsv4_attention
- Verifies against expected output ("Paris" for "The capital of France is")
- No vLLM, no sglang, no serving framework — just the kernel
"""
import os
import sys
import time
import torch
import json
from pathlib import Path
# ---- Paths ----
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
VENV = "/root/dsv4-nvfp4-workspace/venv"
# ---- Config ----
MAX_NEW_TOKENS = 10
PROMPT = "The capital of France is"
# =====================================================================
# Weight loading — stream from safetensors shards
# =====================================================================
class CheckpointReader:
"""Lazy reader for DSV4 safetensors shards.
Instead of loading all 95 shards (945GB), provides per-layer access
by loading individual shards and extracting the relevant keys.
"""
def __init__(self, checkpoint_dir: str):
self.dir = Path(checkpoint_dir)
self._index = None
self._shard_cache = {} # shard_idx -> dict
self._weight_map = None # key -> shard_idx
self._build_index()
def _build_index(self):
"""Build the weight→shard mapping from the model.safetensors.index.json."""
index_path = self.dir / "model.safetensors.index.json"
if index_path.exists():
with open(index_path) as f:
idx = json.load(f)
self._weight_map = idx.get("weight_map", {})
else:
# No index — load all shards (will be slow, but works)
self._weight_map = {}
print("WARNING: No index file found, will scan all shards")
def _load_shard(self, shard_name: str):
"""Load a single shard file."""
if shard_name in self._shard_cache:
return self._shard_cache[shard_name]
path = self.dir / shard_name
if not path.exists():
return None
from safetensors.torch import load_file
print(f" Loading shard: {shard_name}")
data = load_file(str(path))
self._shard_cache[shard_name] = data
return data
def get_weight(self, key: str):
"""Get a single weight tensor by key."""
if self._weight_map and key in self._weight_map:
shard_name = self._weight_map[key]
shard = self._load_shard(shard_name)
if shard and key in shard:
return shard[key]
# Fallback: scan all shards
for i in range(1, 96):
shard_name = f"model-{i:05d}-of-00095.safetensors"
shard = self._load_shard(shard_name)
if shard and key in shard:
return shard[key]
return None
def get_layer_weights(self, layer_idx: int):
"""Get all weights for a single layer."""
prefix = f"model.layers.{layer_idx}."
weights = {}
if self._weight_map:
# Find which shards contain this layer
shard_names = set()
for key, shard in self._weight_map.items():
if key.startswith(prefix):
shard_names.add(shard)
for shard_name in shard_names:
shard = self._load_shard(shard_name)
if shard:
for key, value in shard.items():
if key.startswith(prefix):
weights[key] = value
else:
# Scan all shards
for i in range(1, 96):
shard_name = f"model-{i:05d}-of-00095.safetensors"
shard = self._load_shard(shard_name)
if shard:
for key, value in shard.items():
if key.startswith(prefix):
weights[key] = value
return weights
def clear_cache(self):
"""Free cached shard data."""
self._shard_cache.clear()
torch.cuda.empty_cache()
# =====================================================================
# Tokenizer — simple BPE via transformers
# =====================================================================
def load_tokenizer():
"""Load the DSV4 tokenizer."""
from transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
# =====================================================================
# Model config from checkpoint
# =====================================================================
def load_config():
"""Load model config from checkpoint."""
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
return json.load(f)
# =====================================================================
# NVFP4 Linear — weight loading + forward
# =====================================================================
class NVFP4Linear:
"""NVFP4 quantized linear layer.
Stores weight as uint8 (2 FP4 per byte) + E4M3 per-16-element scale + global_scale.
Forward: dequantize → BF16 matmul (for baseline; production uses tcgen05 MMA).
"""
def __init__(self, in_features: int, out_features: int):
self.in_features = in_features
self.out_features = out_features
self.weight = None # (out, in/2) uint8 — packed FP4
self.weight_scale = None # (out, in/16) float8_e4m3fn — per-16 scale
self.global_scale = None # scalar float32
self._bias = None
def load_from_checkpoint(self, weight: torch.Tensor, weight_scale: torch.Tensor,
global_scale: torch.Tensor = None):
"""Load from checkpoint tensors."""
self.weight = weight.cuda()
self.weight_scale = weight_scale.cuda()
if global_scale is not None:
self.global_scale = global_scale.cuda()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass: dequantize → BF16 → matmul.
This is the BASELINE path. Production uses tcgen05 MMA.
For verification, BF16 matmul after dequant is correct.
"""
if self.weight is None:
raise RuntimeError("Weights not loaded")
# Dequantize NVFP4 → BF16
# weight: (out, in/2) uint8 — 2 FP4 values per byte
# weight_scale: (out, in/16) float8_e4m3fn — 1 scale per 16 elements
w_bf16 = self._dequant_nvfp4(self.weight, self.weight_scale, self.global_scale)
# Standard BF16 matmul
return torch.nn.functional.linear(x, w_bf16)
def _dequant_nvfp4(self, weight: torch.Tensor, scale: torch.Tensor,
global_scale: torch.Tensor) -> torch.Tensor:
"""Dequantize NVFP4 weight to BF16.
NVFP4: each 16-element group has 1 E4M3 scale.
Each byte contains 2 FP4 (E2M1) values: high nibble = second, low nibble = first.
Dequant = FP4 * E4M3_scale * global_scale
"""
out_dim = weight.shape[0]
in_dim_packed = weight.shape[1] # in_features / 2
in_features = in_dim_packed * 2
group_size = 16
# Unpack nibbles → (out, in) FP4 values
low_nibbles = (weight & 0x0F).to(torch.int8) # (out, in/2)
high_nibbles = (weight >> 4).to(torch.int8) # (out, in/2)
# FP4 E2M1 values: sign(1) + exp(2) + mantissa(1)
# Values: ±{0, 2, 3, 4, 6, 8, 12, Inf}
# Simple LUT approach for correctness
fp4_lut = torch.tensor([0, 2, 3, 4, 6, 8, 12, float('inf')],
dtype=torch.float32, device=weight.device)
# Handle sign bit (bit 3)
low_signs = (low_nibbles >> 3).bool()
low_vals = low_nibbles & 0x07
high_signs = (high_nibbles >> 3).bool()
high_vals = high_nibbles & 0x07
low_f = fp4_lut[low_vals] * torch.where(low_signs, -1.0, 1.0)
high_f = fp4_lut[high_vals] * torch.where(high_signs, -1.0, 1.0)
# Interleave: [low0, high0, low1, high1, ...]
w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)
# Apply per-16-element scales
# scale: (out, in/16) float8_e4m3fn
scale_f = scale.float() # E4M3 → float32
if global_scale is not None:
scale_f = scale_f * global_scale.float()
# Expand scales: (out, in/16) → (out, in)
n_groups = scale_f.shape[1]
scale_expanded = scale_f.repeat_interleave(group_size, dim=1)
w_dequant = (w_f * scale_expanded).to(torch.bfloat16)
return w_dequant
class BF16Linear:
"""Standard BF16 linear layer (for o_a_proj, embeddings, etc)."""
def __init__(self, in_features: int, out_features: int):
self.in_features = in_features
self.out_features = out_features
self.weight = None
def load_from_checkpoint(self, weight: torch.Tensor):
self.weight = weight.cuda().to(torch.bfloat16)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(x, self.weight)
# =====================================================================
# Single layer forward
# =====================================================================
def forward_layer(
x: torch.Tensor, # (T, hidden_size) BF16
layer_weights: dict, # checkpoint weights for this layer
layer_idx: int,
config: dict,
) -> torch.Tensor:
"""Forward pass through one transformer layer.
Simplified baseline: uses BF16 matmul after NVFP4 dequant.
This is mathematically equivalent to the tcgen05 MMA path.
"""
hidden_size = config["hidden_size"]
num_heads = config["num_attention_heads"] # 128 for Pro
head_dim = config["head_dim"] # 512
rope_dim = config["rope_dim"] # 64
n_hc = config.get("n_hc", 4)
nope_dim = head_dim - rope_dim # 448
# ---- mHC pre-block (simplified: identity for baseline) ----
# TODO: implement mHC properly with weights from attn_hc.*
# For baseline, just pass through
# ---- RMSNorm ----
norm_weight = layer_weights.get(f"model.layers.{layer_idx}.self_attn.kv_norm.weight")
# Actually the norm weight key might be different
# Let's skip norm for now (will add once we know the exact key names)
# ---- Attention ----
prefix = f"model.layers.{layer_idx}.self_attn"
# Q projection: q_a_proj (low-rank down) → q_b_proj (low-rank up)
q_a_w = layer_weights.get(f"{prefix}.q_a_proj.weight")
q_a_s = layer_weights.get(f"{prefix}.q_a_proj.weight_scale")
q_b_w = layer_weights.get(f"{prefix}.q_b_proj.weight")
q_b_s = layer_weights.get(f"{prefix}.q_b_proj.weight_scale")
if q_a_w is not None:
q_down = NVFP4Linear(hidden_size, q_a_w.shape[0])
q_down.load_from_checkpoint(q_a_w, q_a_s)
q_up = NVFP4Linear(q_a_w.shape[0] * 2, num_heads * head_dim) # 768*2=1536 for Pro
q_up.load_from_checkpoint(q_b_w, q_b_s)
c_Q = q_down.forward(x)
q = q_up.forward(c_Q)
else:
raise RuntimeError(f"Missing q_a_proj weights for layer {layer_idx}")
# KV projection
kv_w = layer_weights.get(f"{prefix}.kv_proj.weight")
kv_s = layer_weights.get(f"{prefix}.kv_proj.weight_scale")
if kv_w is not None:
kv_down = NVFP4Linear(hidden_size, kv_w.shape[0])
kv_down.load_from_checkpoint(kv_w, kv_s)
kv = kv_down.forward(x) # (T, kv_dim) — depends on layer type
else:
raise RuntimeError(f"Missing kv_proj weights for layer {layer_idx}")
# Reshape Q: (T, n_h * hd) → (n_h, T, hd)
T = q.shape[0]
q_heads = q.reshape(T, num_heads, head_dim).permute(1, 0, 2) # (n_h, T, hd)
# Apply partial RoPE (last 64 dims)
# For baseline, skip RoPE — the kernel handles it internally
# TODO: apply forward_rope_partial
# K/V reshape: MQA (1 KV head)
# kv shape depends on layer type:
# HCA: (T, head_dim) — single stream
# CSA: (T, 4*head_dim) — (Ca, Cb, Za, Zb)
# SWA: (T, head_dim)
kv_dim = kv.shape[-1]
if kv_dim == head_dim:
# HCA or SWA: single KV stream
k = kv.reshape(T, 1, head_dim).permute(1, 0, 2) # (1, T, hd)
v = k.clone()
elif kv_dim == 4 * head_dim:
# CSA: split into Ca, Cb, Za, Zb
ca, cb, za, zb = kv.chunk(4, dim=-1)
# For baseline, just use Ca as K, V = K (simplified)
k = ca.reshape(T, 1, head_dim).permute(1, 0, 2)
v = k.clone()
elif kv_dim == 2 * head_dim:
# HCA: (C, Z)
c, z = kv.chunk(2, dim=-1)
k = c.reshape(T, 1, head_dim).permute(1, 0, 2)
v = k.clone()
else:
raise RuntimeError(f"Unexpected kv_dim={kv_dim} for layer {layer_idx}")
# Run FMHA
from dsv4.kernels.attention.production import dsv4_attention
attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd)
# Reshape back: (n_h, T, hd) → (T, n_h * hd)
attn_out = attn_out.permute(1, 0, 2).reshape(T, num_heads * head_dim)
# Output projection
# o_a_proj: grouped BF16 (n_h * hd, n_groups * o_rank)
# o_b_proj: NVFP4 (n_groups * o_rank, hidden_size)
o_a_w = layer_weights.get(f"{prefix}.o_a_proj.weight")
o_b_w = layer_weights.get(f"{prefix}.o_b_proj.weight")
o_b_s = layer_weights.get(f"{prefix}.o_b_proj.weight_scale")
if o_a_w is not None and o_b_w is not None:
# o_a is BF16 grouped linear — for baseline, treat as dense
o_a = BF16Linear(num_heads * head_dim, o_a_w.shape[0])
o_a.load_from_checkpoint(o_a_w)
o_b = NVFP4Linear(o_a_w.shape[0], hidden_size)
o_b.load_from_checkpoint(o_b_w, o_b_s)
attn_proj = o_a.forward(attn_out)
attn_out = o_b.forward(attn_proj)
else:
raise RuntimeError(f"Missing output projection weights for layer {layer_idx}")
# ---- Residual ----
x = x + attn_out
# ---- FFN (simplified: skip MoE for baseline) ----
# The FFN is a massive MoE with 384 experts, each ~3072×7168.
# For a baseline single-shot test, we can skip the FFN or use a
# simplified version. The FFN is not the kernel under test.
# TODO: implement MoE forward with NVFP4 GEMM
print(f" Layer {layer_idx}: attention OK, skipping FFN for baseline")
return x
# =====================================================================
# Main inference loop
# =====================================================================
def main():
print("=" * 70)
print("DSV4 Single-Shot Inference — Baseline Kernel Verification")
print("=" * 70)
# ---- Load config ----
config = load_config()
num_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_heads = config["num_attention_heads"]
head_dim = config["head_dim"]
print(f"\nModel: {config.get('model_type', 'deepseek_v4')}")
print(f"Layers: {num_layers}, Heads: {num_heads}, Head dim: {head_dim}")
print(f"Hidden: {hidden_size}")
# ---- Load tokenizer ----
print("\nLoading tokenizer...")
tokenizer = load_tokenizer()
input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
print(f"Prompt: '{PROMPT}'")
print(f"Token IDs: {input_ids.tolist()}")
# ---- Load checkpoint reader ----
print("\nInitializing checkpoint reader...")
reader = CheckpointReader(CHECKPOINT_DIR)
# ---- Load embedding + final norm + lm_head ----
print("\nLoading embedding layer...")
embed_weight = reader.get_weight("model.embed_tokens.weight")
if embed_weight is not None:
embed = torch.nn.Embedding.from_pretrained(embed_weight.cuda().to(torch.bfloat16))
else:
raise RuntimeError("Missing embedding weights")
print("Loading final norm + lm_head...")
norm_weight = reader.get_weight("model.norm.weight")
if norm_weight is None:
# Try alternate key
norm_weight = reader.get_weight("model.model.norm.weight")
lm_head_weight = reader.get_weight("lm_head.weight")
if lm_head_weight is None:
# Often tied with embedding
lm_head_weight = embed_weight
print(" lm_head tied with embedding")
lm_head = BF16Linear(hidden_size, config["vocab_size"])
lm_head.load_from_checkpoint(lm_head_weight.cuda().to(torch.bfloat16))
# ---- Decode loop ----
print(f"\nStarting decode loop (max {MAX_NEW_TOKENS} tokens)...")
generated_ids = input_ids[0].tolist()
for step in range(MAX_NEW_TOKENS):
t0 = time.time()
current_pos = len(generated_ids) - 1
token_id = torch.tensor([generated_ids[-1]], dtype=torch.long, device='cuda')
# Embed
x = embed(token_id).unsqueeze(0) # (1, 1, hidden_size) → (1, hidden_size)
x = x.squeeze(0) # (1, hidden_size) for T=1 decode
# Process through layers
for layer_idx in range(num_layers):
layer_weights = reader.get_layer_weights(layer_idx)
if not layer_weights:
print(f" WARNING: No weights for layer {layer_idx}, skipping")
continue
x = forward_layer(x, layer_weights, layer_idx, config)
# Free layer weights after use
del layer_weights
if layer_idx % 10 == 9:
reader.clear_cache()
torch.cuda.empty_cache()
# Final norm + lm_head
if norm_weight is not None:
x_f = x.float()
rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(1e-6).rsqrt()
x = (x_f * rms * norm_weight.cuda().float()).to(torch.bfloat16)
logits = lm_head.forward(x) # (1, vocab_size)
next_token = torch.argmax(logits, dim=-1).item()
generated_ids.append(next_token)
token_str = tokenizer.decode([next_token])
elapsed = time.time() - t0
print(f" Step {step}: token={next_token} '{token_str}' ({elapsed:.2f}s)")
# Stop on EOS
if next_token == tokenizer.eos_token_id:
break
# ---- Output ----
output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Output: '{output_text}'")
print(f"{'='*70}")
# Verify
if "Paris" in output_text or "paris" in output_text.lower():
print("✅ PASSED: Model produced 'Paris' — kernel is correct!")
else:
print(f"⚠️ Model did not produce 'Paris'. Output: {output_text}")
print(" This could be due to: missing FFN, missing RoPE, missing mHC,")
print(" incomplete weight loading, or other integration gaps.")
print(" The kernel FMHA itself is verified separately (cos ≥ 0.999993).")
if __name__ == "__main__":
main()