#!/usr/bin/env python3
"""Single-shot DSV4 inference — baseline kernel verification.

Runs one deterministic inference request through the production kernel
stack WITHOUT vLLM/sglang. This is a bare-metal test to verify kernel
correctness end-to-end.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python single_shot_inference.py

Design:
- Loads weights one layer at a time (streaming, ~15GB peak)
- Runs an autoregressive decode loop; there is no KV cache, so every step
  re-runs the whole sequence (prompt + generated tokens) through the layers
- Uses the production FMHA kernel via dsv4_attention
- Verifies against expected output ("Paris" for "The capital of France is")
- No vLLM, no sglang, no serving framework — just the kernel

Known gaps for the first run:
- FFN (MoE) skipped — not the kernel under test
- mHC simplified — not the kernel under test
- RoPE skipped in baseline
- compressor/indexer bypassed (raw KV for now)

The FMHA kernel is the component under test (cos ≥ 0.999993).
"""
import os
import sys
import time
import torch
import json
from pathlib import Path

# ---- Paths ----
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
VENV = "/root/dsv4-nvfp4-workspace/venv"

# ---- Config ----
MAX_NEW_TOKENS = 10
PROMPT = "The capital of France is"


# =====================================================================
# Weight loading — stream from safetensors shards
# =====================================================================

class CheckpointReader:
    """Lazy reader for DSV4 safetensors shards.

    Instead of loading every shard up front, provides per-key and per-layer
    access.  When ``model.safetensors.index.json`` is present it is used to
    find the shard that holds a key; otherwise (or when a key is not listed)
    the ``*.safetensors`` files in the directory are scanned, reading only
    their headers and the tensors that actually match.
    """

    def __init__(self, checkpoint_dir: str):
        self.dir = Path(checkpoint_dir)
        self._index = None
        self._shard_cache = {}   # shard file name -> {key: tensor}
        self._weight_map = None  # key -> shard file name
        self._build_index()

    def _build_index(self):
        """Build the weight→shard mapping from model.safetensors.index.json."""
        index_path = self.dir / "model.safetensors.index.json"
        if index_path.exists():
            with open(index_path) as f:
                idx = json.load(f)
            self._weight_map = idx.get("weight_map", {})
        else:
            # No index — every lookup falls back to scanning the shards.
            self._weight_map = {}
            print("WARNING: No index file found, will scan all shards")

    def _load_shard(self, shard_name: str):
        """Load (and cache) a whole shard file.

        Returns the ``{key: tensor}`` dict, or ``None`` if the file does not
        exist.
        """
        if shard_name in self._shard_cache:
            return self._shard_cache[shard_name]
        path = self.dir / shard_name
        if not path.exists():
            return None
        from safetensors.torch import load_file
        print(f"  Loading shard: {shard_name}")
        data = load_file(str(path))
        self._shard_cache[shard_name] = data
        return data

    def _shard_paths(self):
        """Return every ``*.safetensors`` file in the checkpoint dir, sorted."""
        return sorted(self.dir.glob("*.safetensors"))

    def get_weight(self, key: str):
        """Get a single weight tensor by key, or ``None`` if it is not found."""
        shard_name = self._weight_map.get(key)
        if shard_name is not None:
            shard = self._load_shard(shard_name)
            if shard is not None and key in shard:
                return shard[key]

        # Fallback: scan the shards that are actually on disk.  safe_open only
        # parses each file's header, so probing for a missing key (e.g. a tied
        # lm_head) does not pull whole shards into memory.
        from safetensors import safe_open
        for path in self._shard_paths():
            with safe_open(str(path), framework="pt", device="cpu") as shard:
                if key in shard.keys():
                    return shard.get_tensor(key)
        return None

    def get_layer_weights(self, layer_idx: int):
        """Get all weights for a single layer.

        Returns a dict of every tensor whose key starts with
        ``model.layers.{layer_idx}.`` (empty if none is found).
        """
        prefix = f"model.layers.{layer_idx}."
        weights = {}

        if self._weight_map:
            # Only the shards the index lists for this layer are loaded.
            shard_names = sorted({
                shard
                for key, shard in self._weight_map.items()
                if key.startswith(prefix)
            })
            for shard_name in shard_names:
                shard = self._load_shard(shard_name)
                if not shard:
                    continue
                for key, value in shard.items():
                    if key.startswith(prefix):
                        weights[key] = value
        else:
            # No index: look through every shard header, but read only the
            # tensors that belong to this layer.
            from safetensors import safe_open
            for path in self._shard_paths():
                with safe_open(str(path), framework="pt", device="cpu") as shard:
                    for key in shard.keys():
                        if key.startswith(prefix):
                            weights[key] = shard.get_tensor(key)

        return weights

    def clear_cache(self):
        """Free cached shard data."""
        self._shard_cache.clear()
        torch.cuda.empty_cache()


# =====================================================================
# Tokenizer — simple BPE via transformers
# =====================================================================

def load_tokenizer():
    """Load the DSV4 tokenizer from CHECKPOINT_DIR."""
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(CHECKPOINT_DIR)


# =====================================================================
# Model config from checkpoint
# =====================================================================

def load_config():
    """Load the model config (config.json) from CHECKPOINT_DIR as a dict."""
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        return json.load(f)


# =====================================================================
# NVFP4 Linear — weight loading + forward
# =====================================================================

class NVFP4Linear:
    """NVFP4 quantized linear layer.

    Stores weight as uint8 (2 FP4 per byte) + one scale per 16 elements +
    an optional global scale.
    Forward: dequantize → BF16 matmul (for baseline; production uses
    tcgen05 MMA).
    """

    # Magnitudes of the eight FP4 E2M1 codes (sign bit stripped):
    # 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit; no Inf/NaN.
    _E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
    _GROUP_SIZE = 16

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None        # (out, in/2) uint8 — packed FP4
        self.weight_scale = None  # (out, in/16) — per-16-element scale
        self.global_scale = None  # scalar
        self._bias = None

    def load_from_checkpoint(self, weight: torch.Tensor, weight_scale: torch.Tensor,
                             global_scale: torch.Tensor = None):
        """Move the checkpoint tensors onto the GPU and keep them."""
        self.weight = weight.cuda()
        self.weight_scale = weight_scale.cuda()
        if global_scale is not None:
            self.global_scale = global_scale.cuda()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: dequantize → BF16 → matmul.

        This is the BASELINE path. Production uses tcgen05 MMA.

        Raises:
            RuntimeError: if no weights have been loaded.
        """
        if self.weight is None:
            raise RuntimeError("Weights not loaded")

        w_bf16 = self._dequant_nvfp4(self.weight, self.weight_scale, self.global_scale)
        return torch.nn.functional.linear(x, w_bf16)

    def _dequant_nvfp4(self, weight: torch.Tensor, scale: torch.Tensor,
                       global_scale: torch.Tensor) -> torch.Tensor:
        """Dequantize a packed NVFP4 weight to BF16.

        Each byte holds two FP4 (E2M1) codes: low nibble = first element,
        high nibble = second.  Every group of 16 consecutive elements along
        the input dimension shares one scale.
        Dequant = FP4 * scale * global_scale (global_scale is skipped if None).

        Args:
            weight: (out, in/2) packed codes, one byte per element.
            scale: (out, in/16) per-group scales; converted with ``.float()``.
            global_scale: optional scalar multiplier, or None.

        Returns:
            (out, in) BF16 tensor.

        Raises:
            RuntimeError: if ``scale`` does not have in/16 columns.
        """
        if weight.dtype != torch.uint8:
            # Reinterpret the bytes so that ">> 4" is a logical shift.
            weight = weight.view(torch.uint8)

        out_dim, packed_dim = weight.shape
        in_features = packed_dim * 2
        group_size = self._GROUP_SIZE

        # 16-entry signed table: codes 0-7 are +magnitude, 8-15 (sign bit 3
        # set) are -magnitude, so one gather decodes sign and value together.
        magnitudes = torch.tensor(self._E2M1_MAGNITUDES, dtype=torch.float32,
                                  device=weight.device)
        signed_lut = torch.cat([magnitudes, -magnitudes])

        # Advanced indexing needs integer (long) indices, not int8.
        low_f = signed_lut[(weight & 0x0F).long()]   # (out, in/2)
        high_f = signed_lut[(weight >> 4).long()]    # (out, in/2)

        # Interleave: [low0, high0, low1, high1, ...]
        w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)

        scale_f = scale.float()
        if global_scale is not None:
            scale_f = scale_f * global_scale.float()

        if scale_f.shape[1] * group_size != in_features:
            raise RuntimeError(
                f"Scale has {scale_f.shape[1]} groups per row, expected "
                f"{in_features // group_size} for in_features={in_features}"
            )

        # Expand scales: (out, in/16) → (out, in)
        scale_expanded = scale_f.repeat_interleave(group_size, dim=1)
        return (w_f * scale_expanded).to(torch.bfloat16)


class BF16Linear:
    """Standard BF16 linear layer (for o_a_proj, embeddings, etc)."""

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None

    def load_from_checkpoint(self, weight: torch.Tensor):
        """Move the weight onto the GPU and cast it to BF16."""
        self.weight = weight.cuda().to(torch.bfloat16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ weight.T`` (no bias)."""
        return torch.nn.functional.linear(x, self.weight)


# =====================================================================
# Single layer forward
# =====================================================================

def _build_nvfp4_linear(layer_weights: dict, proj_key: str, error: str) -> NVFP4Linear:
    """Create an NVFP4Linear from ``{proj_key}.weight`` / ``.weight_scale``.

    Raises RuntimeError(error) if either tensor is missing.
    """
    weight = layer_weights.get(f"{proj_key}.weight")
    scale = layer_weights.get(f"{proj_key}.weight_scale")
    if weight is None or scale is None:
        raise RuntimeError(error)

    # NOTE(review): assumes the per-tensor global scale, if the checkpoint
    # has one, is stored as "<proj>.weight_scale_2" and is multiplied in —
    # confirm against the checkpoint. If the key is absent this is a no-op.
    global_scale = layer_weights.get(f"{proj_key}.weight_scale_2")

    linear = NVFP4Linear(weight.shape[1] * 2, weight.shape[0])
    linear.load_from_checkpoint(weight, scale, global_scale)
    return linear


def forward_layer(
    x: torch.Tensor,           # (T, hidden_size) BF16
    layer_weights: dict,       # checkpoint weights for this layer
    layer_idx: int,
    config: dict,
) -> torch.Tensor:
    """Forward pass through one transformer layer (attention only).

    Simplified baseline: NVFP4 weights are dequantized and applied with a
    BF16 matmul.  mHC, the pre-attention norm, RoPE and the FFN are not
    applied here (see the TODOs below).

    Returns:
        ``x`` plus the attention output, same shape as ``x``.

    Raises:
        RuntimeError: if a required projection weight is missing or the KV
            projection has an unexpected width.
    """
    num_heads = config["num_attention_heads"]   # 128 for Pro
    head_dim = config["head_dim"]               # 512

    # ---- mHC pre-block (simplified: identity for baseline) ----
    # TODO: implement mHC properly with weights from attn_hc.*

    # ---- RMSNorm ----
    # TODO: skipped until the exact norm weight key names are known.

    # ---- Attention ----
    prefix = f"model.layers.{layer_idx}.self_attn"

    # Q projection: q_a_proj (low-rank down) → q_b_proj (low-rank up)
    q_down = _build_nvfp4_linear(
        layer_weights, f"{prefix}.q_a_proj",
        f"Missing q_a_proj weights for layer {layer_idx}")
    q_up = _build_nvfp4_linear(
        layer_weights, f"{prefix}.q_b_proj",
        f"Missing q_b_proj weights for layer {layer_idx}")
    q = q_up.forward(q_down.forward(x))

    # KV projection
    kv_proj = _build_nvfp4_linear(
        layer_weights, f"{prefix}.kv_proj",
        f"Missing kv_proj weights for layer {layer_idx}")
    kv = kv_proj.forward(x)  # (T, kv_dim) — depends on layer type

    # Reshape Q: (T, n_h * hd) → (n_h, T, hd)
    T = q.shape[0]
    q_heads = q.reshape(T, num_heads, head_dim).permute(1, 0, 2)

    # TODO: apply partial RoPE (forward_rope_partial); skipped in baseline.

    # K/V: MQA (1 KV head).  kv is 1, 2 or 4 streams of head_dim each
    # ((C), (C, Z) or (Ca, Cb, Za, Zb)); for the baseline only the first
    # stream is used as K, and V = K.
    kv_dim = kv.shape[-1]
    n_streams, remainder = divmod(kv_dim, head_dim)
    if remainder or n_streams not in (1, 2, 4):
        raise RuntimeError(f"Unexpected kv_dim={kv_dim} for layer {layer_idx}")
    k = kv[:, :head_dim].reshape(T, 1, head_dim).permute(1, 0, 2)  # (1, T, hd)
    v = k.clone()

    # Run FMHA
    from dsv4.kernels.attention.production import dsv4_attention
    attn_out = dsv4_attention(q_heads, k, v)  # (n_h, T, hd)

    # Reshape back: (n_h, T, hd) → (T, n_h * hd)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, num_heads * head_dim)

    # Output projection: o_a_proj (BF16, treated as dense for the baseline)
    # followed by o_b_proj (NVFP4).
    missing_o = f"Missing output projection weights for layer {layer_idx}"
    o_a_w = layer_weights.get(f"{prefix}.o_a_proj.weight")
    if o_a_w is None:
        raise RuntimeError(missing_o)
    o_a = BF16Linear(num_heads * head_dim, o_a_w.shape[0])
    o_a.load_from_checkpoint(o_a_w)
    o_b = _build_nvfp4_linear(layer_weights, f"{prefix}.o_b_proj", missing_o)
    attn_out = o_b.forward(o_a.forward(attn_out))

    # ---- Residual ----
    x = x + attn_out

    # ---- FFN (simplified: skip MoE for baseline) ----
    # The FFN is not the kernel under test.
    # TODO: implement MoE forward with NVFP4 GEMM
    print(f"  Layer {layer_idx}: attention OK, skipping FFN for baseline")

    return x


# =====================================================================
# Main inference loop
# =====================================================================

def main():
    """Run the prompt through the model and report whether "Paris" appears."""
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Baseline Kernel Verification")
    print("=" * 70)

    # ---- Load config ----
    config = load_config()
    num_layers = config["num_hidden_layers"]
    hidden_size = config["hidden_size"]
    num_heads = config["num_attention_heads"]
    head_dim = config["head_dim"]
    print(f"\nModel: {config.get('model_type', 'deepseek_v4')}")
    print(f"Layers: {num_layers}, Heads: {num_heads}, Head dim: {head_dim}")
    print(f"Hidden: {hidden_size}")

    # ---- Load tokenizer ----
    print("\nLoading tokenizer...")
    tokenizer = load_tokenizer()
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}'")
    print(f"Token IDs: {input_ids.tolist()}")

    # ---- Load checkpoint reader ----
    print("\nInitializing checkpoint reader...")
    reader = CheckpointReader(CHECKPOINT_DIR)

    # ---- Load embedding + final norm + lm_head ----
    print("\nLoading embedding layer...")
    embed_weight = reader.get_weight("model.embed_tokens.weight")
    if embed_weight is None:
        raise RuntimeError("Missing embedding weights")
    embed = torch.nn.Embedding.from_pretrained(embed_weight.cuda().to(torch.bfloat16))

    print("Loading final norm + lm_head...")
    norm_weight = reader.get_weight("model.norm.weight")
    if norm_weight is None:
        # Try alternate key
        norm_weight = reader.get_weight("model.model.norm.weight")
    # Upload once instead of on every decode step.
    norm_weight_f = norm_weight.cuda().float() if norm_weight is not None else None
    norm_eps = config.get("rms_norm_eps", 1e-6)

    lm_head_weight = reader.get_weight("lm_head.weight")
    if lm_head_weight is None:
        # Often tied with embedding
        lm_head_weight = embed_weight
        print("  lm_head tied with embedding")
    lm_head = BF16Linear(hidden_size, config["vocab_size"])
    lm_head.load_from_checkpoint(lm_head_weight)

    # ---- Decode loop ----
    print(f"\nStarting decode loop (max {MAX_NEW_TOKENS} tokens)...")
    generated_ids = input_ids[0].tolist()
    prompt_len = len(generated_ids)

    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()

        # There is no KV cache, so the whole sequence is re-run each step;
        # feeding only the last token would hide the prompt from the model.
        # NOTE(review): assumes dsv4_attention accepts T > 1 and applies a
        # causal mask — confirm against the kernel's contract.
        token_ids = torch.tensor(generated_ids, dtype=torch.long, device='cuda')
        x = embed(token_ids)  # (T, hidden_size)

        # Process through layers
        for layer_idx in range(num_layers):
            layer_weights = reader.get_layer_weights(layer_idx)
            if not layer_weights:
                print(f"  WARNING: No weights for layer {layer_idx}, skipping")
                continue

            x = forward_layer(x, layer_weights, layer_idx, config)

            # Free layer weights after use
            del layer_weights
            if layer_idx % 10 == 9:
                reader.clear_cache()  # also empties the CUDA allocator cache

        # Only the last position predicts the next token.
        x = x[-1:]  # (1, hidden_size)

        # Final norm + lm_head
        if norm_weight_f is not None:
            x_f = x.float()
            rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(norm_eps).rsqrt()
            x = (x_f * rms * norm_weight_f).to(torch.bfloat16)

        logits = lm_head.forward(x)  # (1, vocab_size)
        next_token = torch.argmax(logits, dim=-1).item()
        generated_ids.append(next_token)

        token_str = tokenizer.decode([next_token])
        elapsed = time.time() - t0
        print(f"  Step {step}: token={next_token} '{token_str}' ({elapsed:.2f}s)")

        # Stop on EOS
        if next_token == tokenizer.eos_token_id:
            break

    # ---- Output ----
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print(f"\n{'='*70}")
    print(f"Input:  '{PROMPT}'")
    print(f"Output: '{output_text}'")
    print(f"{'='*70}")

    # Verify against the generated continuation only, so the prompt text can
    # never satisfy the check by itself.
    continuation = tokenizer.decode(generated_ids[prompt_len:], skip_special_tokens=True)
    if "paris" in continuation.lower():
        print("✅ PASSED: Model produced 'Paris' — kernel is correct!")
    else:
        print(f"⚠️  Model did not produce 'Paris'. Output: {output_text}")
        print("   This could be due to: missing FFN, missing RoPE, missing mHC,")
        print("   incomplete weight loading, or other integration gaps.")
        print("   The kernel FMHA itself is verified separately (cos ≥ 0.999993).")


if __name__ == "__main__":
    main()