570 lines
23 KiB
Python
570 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""Single-shot DSV4 inference — 8-GPU with mHC + MoE + KV cache.
|
|
|
|
Full model forward pass with all architectural components:
|
|
- mHC (Manifold-Constrained Hyper-Connections)
|
|
- Q low-rank + KV projection
|
|
- RoPE (partial, last 64 dims)
|
|
- Production FMHA kernel (tcgen05 MMA)
|
|
- Grouped output projection (wo_a BMM + wo_b NVFP4)
|
|
- Routed MoE (384 experts, top-6, hash + dense routing)
|
|
- Shared expert FFN (SwiGLU with clamping)
|
|
- KV cache across decode steps
|
|
|
|
Usage (on B200):
|
|
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
|
cd /root/dsv4-nvfp4-workspace/kernel
|
|
python3 single_shot_inference.py
|
|
"""
|
|
import os, sys, time, json, math
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
# ---- Run configuration --------------------------------------------------
# Directory holding config.json, the tokenizer files and the safetensors shards.
CHECKPOINT_DIR: str = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Text fed to the model before greedy decoding starts.
PROMPT: str = "The capital of France is"
# Upper bound on the number of decode steps.
MAX_NEW_TOKENS: int = 10
# Decoder layer li is placed on cuda:{li % NUM_GPUS}.
NUM_GPUS: int = 8
|
|
|
|
# =====================================================================
|
|
# NVFP4 dequantization
|
|
# =====================================================================
|
|
|
|
# FP4 (E2M1) magnitudes indexed by the low 3 bits of a 4-bit code:
# 0b000->0, 0b001->0.5 (subnormal), 0b010->1, 0b011->1.5,
# 0b100->2, 0b101->3, 0b110->4, 0b111->6. Bit 3 of the code is the sign.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to BF16.

    Args:
        weight: integer tensor of shape (out_dim, in_features // 2). Each byte
            packs two 4-bit E2M1 codes: the even-indexed element in the low
            nibble, the following element in the high nibble. uint8 and int8
            storage are both accepted.
        weight_scale: per-block scales, shape (out_dim, in_features // block_size),
            any dtype convertible with ``.float()`` (e.g. FP8).
        weight_scale_2: tensor-wide scale, broadcastable against ``weight_scale``.
        block_size: number of consecutive input elements sharing one block scale.

    Returns:
        BF16 tensor of shape (out_dim, in_features).
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2

    # Mask after shifting so sign-extended int8 bytes still yield codes 0..15.
    low_codes = (weight & 0x0F).long()
    high_codes = ((weight >> 4) & 0x0F).long()

    # 16-entry signed table: codes 0-7 are +magnitude, codes 8-15 are -magnitude.
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])

    # Interleave so the low nibble comes first: [lo0, hi0, lo1, hi1, ...].
    w_f = torch.stack([signed_lut[low_codes], signed_lut[high_codes]], dim=-1)
    w_f = w_f.reshape(out_dim, in_features)

    scale_f = weight_scale.float() * weight_scale_2.float()
    scale_expanded = scale_f.repeat_interleave(block_size, dim=1)
    return (w_f * scale_expanded).bfloat16()
|
|
|
|
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``x @ W.T`` (no bias) for an NVFP4-packed weight ``W``.

    The packed weight is fully dequantized to BF16 on every call, then a
    regular dense linear is applied.
    """
    return torch.nn.functional.linear(
        x,
        dequant_nvfp4_weight(weight, weight_scale, weight_scale_2),
    )
|
|
|
|
def bf16_linear(x, weight):
    """Compute ``x @ weight.T`` (no bias) after casting ``weight`` to BF16."""
    weight_bf16 = weight.bfloat16()
    return torch.nn.functional.linear(x, weight_bf16)
|
|
|
|
|
|
# =====================================================================
|
|
# mHC
|
|
# =====================================================================
|
|
|
|
class mHCBlock:
    """Manifold-constrained hyper-connection (mHC) wrapper around one sub-layer.

    The residual is carried as ``n_hc`` parallel streams of shape
    (T, n_hc, hidden_dim). ``pre_block`` collapses them into a single
    (T, hidden_dim) sub-layer input and computes the mixing coefficients.
    ``post_block`` mixes the streams with a Sinkhorn-normalized matrix and
    adds the (weighted) sub-layer output into every stream.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_repeat=20, device='cuda'):
        self.n_hc = n_hc
        self.K = n_hc * hidden_dim  # width of all streams flattened together
        self.sinkhorn_repeat = sinkhorn_repeat
        self.device = device
        # Populated by load_from_checkpoint():
        self.fn = None        # (2*n_hc + n_hc**2, K) projection to mixing logits
        self.hc_scale = None  # [pre, post, comb] logit scales
        self.hc_base = None   # (2*n_hc + n_hc**2,) logit biases
        self.rms_eps = 1e-6
        self.hc_pre_eps = 0.0
        self.hc_sinkhorn_eps = 1e-6
        self.hc_post_mult_value = 2.0

    def load_from_checkpoint(self, fn, base, scale):
        """Copy the mixing parameters to ``self.device`` as contiguous float32."""
        self.fn = fn.to(device=self.device, dtype=torch.float32).contiguous()
        self.hc_base = base.to(device=self.device, dtype=torch.float32).contiguous()
        self.hc_scale = scale.to(device=self.device, dtype=torch.float32).contiguous()

    def pre_block(self, residual):
        """Collapse the streams into one sub-layer input.

        Args:
            residual: (T, n_hc, hidden_dim) streams.

        Returns:
            ``(layer_input, (post_mix, comb_mix))``. ``layer_input`` is the
            (T, hidden_dim) BF16 weighted sum of the streams. ``post_mix`` is
            (T, n_hc) and ``comb_mix`` is the (T, n_hc, n_hc)
            Sinkhorn-normalized stream-mixing matrix. Pass the tuple to
            ``post_block``.
        """
        n = self.n_hc
        K = self.K
        T = residual.shape[0]

        # Mixing logits come from the RMS-normalized flattened streams.
        res_flat = residual.reshape(T, K).float()
        mixes = torch.matmul(res_flat, self.fn.t())
        sqrsum = res_flat.square().sum(dim=-1, keepdim=True)
        mixes = mixes * torch.rsqrt(sqrsum / K + self.rms_eps)

        pre_logits = mixes[:, :n] * self.hc_scale[0] + self.hc_base[:n]
        pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps
        post_logits = mixes[:, n:2 * n] * self.hc_scale[1] + self.hc_base[n:2 * n]
        post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value

        comb_logits = (mixes[:, 2 * n:].reshape(T, n, n) * self.hc_scale[2]
                       + self.hc_base[2 * n:].reshape(1, n, n))
        # Sinkhorn iterations: alternate row / column normalization toward a
        # doubly stochastic matrix. The final step normalizes over dim=-2.
        comb_mix = torch.softmax(comb_logits, dim=-1) + self.hc_sinkhorn_eps
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)
        for _ in range(self.sinkhorn_repeat - 1):
            comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + self.hc_sinkhorn_eps)
            comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps)

        layer_input = (pre_mix.unsqueeze(-1) * residual.float()).sum(dim=1).bfloat16()
        return layer_input, (post_mix, comb_mix)

    def post_block(self, residual, F_out, ctx):
        """Mix the streams and add the sub-layer output.

        Args:
            residual: (T, n_hc, hidden_dim) streams that were fed to ``pre_block``.
            F_out: (T, hidden_dim) sub-layer output.
            ctx: ``(post_mix, comb_mix)`` returned by ``pre_block``.

        Returns:
            (T, n_hc, hidden_dim) BF16 streams for the next sub-layer.
        """
        post_mix, comb_mix = ctx
        # Output stream j = sum_i comb_mix[t, i, j] * residual[t, i]. The
        # Sinkhorn loop ends by normalizing over i, so each output stream is a
        # (near-)convex combination of the input streams. Using the stream
        # index j on both inputs would only rescale streams, never mix them.
        mixed_residual = torch.einsum('tij,tih->tjh', comb_mix, residual.float())
        post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float()
        residual_next = mixed_residual + post_term

        # Emergency RMSNorm (remove once MoE provides balance).
        # NOTE(review): per-stream normalization appears to be a stabilization
        # workaround rather than part of the mHC update itself.
        _T = residual_next.shape[0]
        rn_f = residual_next.reshape(_T, self.n_hc, -1)
        rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        return (rn_f * rms).bfloat16()
|
|
|
|
|
|
# =====================================================================
|
|
# RoPE
|
|
# =====================================================================
|
|
|
|
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
    """Precompute RoPE cos/sin tables for positions ``0 .. max_pos-1``.

    Args:
        max_pos: number of positions to tabulate.
        head_dim: accepted for interface compatibility; not used. Only the
            trailing ``rope_dim`` channels of a head are rotated.
        rope_dim: number of rotated channels (one frequency per channel pair).
        device: device the tables are moved to.
        theta: RoPE base.

    Returns:
        ``(cos, sin)``: float32 tensors of shape (max_pos, ceil(rope_dim / 2))
        on ``device``, indexed ``[position, pair]``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
|
|
|
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
    """Rotate the trailing ``rope_dim`` channels of ``x`` as interleaved pairs.

    ``x`` must be 3-D, (T, n_heads, head_dim). Channels from
    ``head_dim - rope_dim`` on are taken as (even, odd) pairs and rotated by
    the angles stored at ``positions`` in the cos/sin tables. The leading
    channels are copied through unchanged, and a new tensor is returned.
    """
    _, _, head_dim = x.shape
    n_pass = head_dim - rope_dim

    cos_t = cos_cache[positions].unsqueeze(1).to(x.dtype)
    sin_t = sin_cache[positions].unsqueeze(1).to(x.dtype)

    # Read the pairs from the untouched input so both outputs use old values.
    rot_in = x[:, :, n_pass:]
    even, odd = rot_in[..., 0::2], rot_in[..., 1::2]

    result = x.clone()
    rot_out = result[:, :, n_pass:]  # view into result
    rot_out[..., 0::2] = even * cos_t - odd * sin_t
    rot_out[..., 1::2] = even * sin_t + odd * cos_t
    return result
|
|
|
|
|
|
# =====================================================================
|
|
# Simple KV cache — flat BF16 K and V buffers per layer
|
|
# =====================================================================
|
|
|
|
class SimpleKVCache:
    """Per-layer K/V cache for incremental decoding.

    Each layer owns two preallocated BF16 buffers (K and V) of shape
    (1, max_seq, head_dim): a single KV head (MQA). ``len[layer]`` counts
    how many leading rows of that layer's buffers are filled.
    """

    def __init__(self, n_layers, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
                  for _ in range(n_layers)]
        self.v = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
                  for _ in range(n_layers)]
        self.len = [0] * n_layers  # filled rows per layer

    def append(self, layer_idx, k_new, v_new):
        """Append new K/V rows for ``layer_idx``.

        Args:
            layer_idx: which layer's buffers to write.
            k_new, v_new: (1, T, head_dim). All T rows are stored; the
                single-token form (1, 1, head_dim) is the T == 1 case.

        Raises:
            IndexError: if the cache would exceed ``max_seq`` rows.
        """
        pos = self.len[layer_idx]
        end = pos + k_new.shape[1]
        if end > self.max_seq:
            raise IndexError(
                f"KV cache overflow on layer {layer_idx}: {end} rows > max_seq={self.max_seq}"
            )
        self.k[layer_idx][0, pos:end] = k_new[0]
        self.v[layer_idx][0, pos:end] = v_new[0]
        self.len[layer_idx] = end

    def get(self, layer_idx):
        """Return views of the filled K and V rows, each (1, len, head_dim)."""
        filled = self.len[layer_idx]
        return self.k[layer_idx][:, :filled], self.v[layer_idx][:, :filled]
|
|
|
|
|
|
# =====================================================================
|
|
# Routed MoE forward
|
|
# =====================================================================
|
|
|
|
def _swiglu_expert(x, w, prefix, swiglu_limit):
    """Run one NVFP4 SwiGLU expert whose projections live under ``prefix``."""
    def proj(inp, name):
        return nvfp4_linear(inp, w[f"{prefix}.{name}.weight"],
                            w[f"{prefix}.{name}.weight_scale"],
                            w[f"{prefix}.{name}.weight_scale_2"])

    gate = proj(x, "gate_proj")
    up = proj(x, "up_proj")
    # SiLU(gate) is clamped to +/- swiglu_limit before the gated product.
    # NOTE(review): confirm clamp placement (gate vs. up, before/after SiLU)
    # against the reference DSV4 expert.
    silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit)
    hidden = (silu_out * up.float()).bfloat16()
    return proj(hidden, "down_proj")


def moe_forward(x, w, li, cfg, token_id):
    """Run routed MoE + shared expert for a single token.

    Args:
        x: (1, H) BF16 input from the FFN mHC pre_block.
        w: weight dict for layer ``li`` (keys ``model.layers.{li}.mlp...``).
        li: layer index.
        cfg: model config dict.
        token_id: tensor holding the current token id (used by hash routing;
            only its first element is read).

    Returns:
        (1, H) BF16: ``routed_scaling_factor`` x (weighted sum of the selected
        experts' outputs), plus the shared expert output (zeros if the layer
        has no shared-expert weights).

    NOTE(review): routing reads only the first row of ``x`` / ``token_id``,
    so this assumes one token per call.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    n_hash_layers = cfg.get("num_hash_layers", 3)  # defaults to the original 3
    mlp = f"model.layers.{li}.mlp"

    if li < n_hash_layers:
        # Hash routing: fixed token-id -> expert table, uniform weights.
        # NOTE(review): confirm whether the reference gate still weights
        # hashed experts by their gate scores instead of 1/top_k.
        tid2eid = w[f"{mlp}.gate.tid2eid"]
        tid = token_id.reshape(-1)[0].item()
        expert_ids = tid2eid[tid]
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
    else:
        # Dense routing: score = sqrt(softplus(gate logits)), top-k, renormalize.
        # NOTE(review): no score-correction bias is applied to the selection.
        logits = bf16_linear(x, w[f"{mlp}.gate.weight"])
        activated = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6).bfloat16()
        scores, indices = activated.float().topk(top_k, dim=-1)
        expert_ids = indices[0]
        expert_weights = scores[0] / scores[0].sum()

    # Accumulate in fp32 and round to BF16 once, instead of once per expert.
    routed = torch.zeros_like(x, dtype=torch.float32)
    for eid, weight in zip(expert_ids.tolist(), expert_weights):
        expert_out = _swiglu_expert(x, w, f"{mlp}.experts.{eid}", swiglu_limit)
        routed += expert_out.float() * weight
    routed_out = (routed * routed_scaling).bfloat16()

    se_prefix = f"{mlp}.shared_experts"
    if f"{se_prefix}.gate_proj.weight" in w:
        shared_out = _swiglu_expert(x, w, se_prefix, swiglu_limit)
    else:
        shared_out = torch.zeros_like(x)

    return routed_out + shared_out
|
|
|
|
|
|
# =====================================================================
|
|
# Weight loading
|
|
# =====================================================================
|
|
|
|
def load_all_weights(checkpoint_dir, num_layers):
    """Load every safetensors shard and place its tensors on their GPUs.

    Placement:
    - Decoder-layer tensors (``model.layers.{li}.*``) go to ``cuda:{li % NUM_GPUS}``.
    - ``model.embed_tokens*``, ``model.norm*`` and ``lm_head*`` go to ``cuda:0``.
    - Every other key is dropped.

    Each shard is moved to GPU right after it is read, so host memory holds
    roughly one shard at a time rather than the whole checkpoint.

    Args:
        checkpoint_dir: directory with the shards and, optionally,
            ``model.safetensors.index.json`` (its ``weight_map`` lists the shards).
        num_layers: unused; kept for interface compatibility.

    Returns:
        ``(layer_weights, global_weights)``. ``layer_weights[li]`` maps tensor
        names to tensors and also carries ``"_device"`` / ``"_gpu"`` entries.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})
    shard_names = set(weight_map.values()) if weight_map else {
        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
    }

    layer_weights = {}
    global_weights = {}
    global_prefixes = ("model.embed_tokens", "model.norm", "lm_head")

    print(f"Loading {len(shard_names)} shards...")
    print("Assigning to GPUs...")
    loaded = 0
    n_tensors = 0
    missing = []
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue
        shard = load_file(str(shard_path))
        for key, tensor in shard.items():
            if key.startswith("model.layers."):
                li = int(key.split(".")[2])
                target_gpu = li % NUM_GPUS
                target_device = f"cuda:{target_gpu}"
                bucket = layer_weights.setdefault(
                    li, {"_device": target_device, "_gpu": target_gpu})
                bucket[key] = tensor.to(target_device)
            elif key.startswith(global_prefixes):
                global_weights[key] = tensor.to("cuda:0")
        n_tensors += len(shard)
        del shard  # drop the host copy before reading the next shard
        loaded += 1
        if loaded % 20 == 0:
            print(f"  {loaded}/{len(shard_names)} shards, {n_tensors} tensors")
    if missing:
        # Previously skipped silently; surface it so later KeyErrors are explainable.
        print(f"  WARNING: {len(missing)} shard(s) not found, e.g. {missing[0]}")
    print(f"  Done: {n_tensors} tensors")

    for gpu in range(NUM_GPUS):
        alloc = torch.cuda.memory_allocated(gpu) / 1e9
        print(f"  GPU {gpu}: {alloc:.1f}GB")
    return layer_weights, global_weights
|
|
|
|
|
|
# =====================================================================
|
|
# Single layer forward
|
|
# =====================================================================
|
|
|
|
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
                  kv_cache, token_id, decode_pos):
    """Forward one decoder layer (mHC + attention + MoE) for the current token(s).

    Args:
        X_l: (T, n_hc, H) BF16 hyper-connection streams on this layer's GPU.
        w: this layer's weight dict.
        li: layer index.
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables from ``build_rope_cache`` on this GPU.
        attn_mhc, ffn_mhc: ``mHCBlock`` instances for this layer.
        kv_cache: this layer's ``SimpleKVCache`` (slot 0 is used).
        token_id: current token id tensor (used by hash-routed MoE).
        decode_pos: absolute position of the first of the T tokens.

    Returns:
        Updated (T, n_hc, H) BF16 streams.

    Raises:
        ValueError: if ``attn_mhc`` or ``ffn_mhc`` is None (mHC weights missing).

    NOTE(review): intended for T == 1. No causal mask is built among new
    tokens, and MoE routing only looks at the first token.
    """
    if attn_mhc is None or ffn_mhc is None:
        raise ValueError(f"layer {li}: mHC weights (attn_hc / ffn_hc) were not loaded")

    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    o_rank = cfg["o_lora_rank"]
    o_groups = cfg["o_groups"]
    group_input_dim = (n_h // o_groups) * hd
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]

    def proj(inp, name):
        # NVFP4 linear for "{pre}.{name}" (packed weight + block scale + global scale).
        return nvfp4_linear(inp, w[f"{pre}.{name}.weight"],
                            w[f"{pre}.{name}.weight_scale"],
                            w[f"{pre}.{name}.weight_scale_2"])

    # ==== mHC pre_block (attention) ====
    x_in, attn_ctx = attn_mhc.pre_block(X_l)

    # ==== Q (low-rank a -> b) and single-head KV projections ====
    # NOTE(review): no normalization is applied to x_in, c_Q or kv here;
    # confirm the checkpoint carries no input/q/kv norms that should be used.
    q = proj(proj(x_in, "q_a_proj"), "q_b_proj")
    kv = proj(x_in, "kv_proj")

    # ==== RoPE on Q in (T, n_h, hd) layout, the layout apply_rope expects ====
    q_pos = decode_pos + torch.arange(T, dtype=torch.long, device=device)
    q_heads = apply_rope(q.reshape(T, n_h, hd), q_pos, rope_cos, rope_sin, rd)
    q_heads = q_heads.permute(1, 0, 2)  # (n_h, T, hd)

    # ==== KV cache: one KV head; V is a copy of the un-rotated projection ====
    k_new = kv.reshape(T, 1, hd).permute(1, 0, 2)  # (1, T, hd)
    v_new = k_new.clone()
    kv_cache.append(0, k_new, v_new)
    k_full, v_full = kv_cache.get(0)  # (1, seq_len, hd)
    seq_len = k_full.shape[1]

    # K is cached without RoPE, so every cached row is rotated at its own position.
    k_pos = torch.arange(seq_len, dtype=torch.long, device=device)
    k_full = apply_rope(k_full.permute(1, 0, 2), k_pos, rope_cos, rope_sin, rd).permute(1, 0, 2)

    # ==== FMHA ====
    from dsv4.kernels.attention.production import dsv4_attention
    attn_out = dsv4_attention(q_heads, k_full, v_full)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)

    # ==== Grouped output projection: per-group o_a (BMM), then NVFP4 o_b ====
    attn_grouped = attn_out.reshape(T, o_groups, group_input_dim)
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))  # (o_groups, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = proj(grouped_flat, "o_b_proj")

    # ==== mHC post_block (attention) ====
    X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx)

    # ==== mHC pre_block (FFN) -> MoE + shared expert -> mHC post_block ====
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)
    F_ffn = moe_forward(x_ffn, w, li, cfg, token_id)
    X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)

    return X_l
|
|
|
|
|
|
# =====================================================================
|
|
# Main
|
|
# =====================================================================
|
|
|
|
def _build_mhc_blocks(layer_weights, n_layers, H, n_hc):
    """Build attention/FFN mHCBlocks for every layer that has all three hc tensors.

    Returns two dicts keyed by layer index. A layer missing any of
    ``{attn_hc,ffn_hc}.{fn,base,scale}`` is absent from the corresponding dict.
    """
    attn_blocks, ffn_blocks = {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        lw = layer_weights[li]
        for kind, blocks in (("attn_hc", attn_blocks), ("ffn_hc", ffn_blocks)):
            prefix = f"model.layers.{li}.{kind}"
            fn = lw.get(f"{prefix}.fn")
            base = lw.get(f"{prefix}.base")
            scale = lw.get(f"{prefix}.scale")
            if fn is None or base is None or scale is None:
                continue
            mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
            mhc.load_from_checkpoint(fn, base, scale)
            blocks[li] = mhc
    return attn_blocks, ffn_blocks


def main():
    """Run the end-to-end demo: load weights, prefill PROMPT, greedily decode.

    Decoder layer ``li`` runs on ``cuda:{li % NUM_GPUS}``. Embedding, final
    norm and ``lm_head`` live on ``cuda:0``. At most ``MAX_NEW_TOKENS``
    tokens are generated; decoding stops early on NaN logits or on the
    tokenizer's EOS id.
    """
    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+MoE+KV)")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    n_hc = 4
    final_norm_eps = cfg.get("rms_norm_eps", 1e-6)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights ====
    print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
    layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Build mHC blocks ====
    print("Building mHC blocks...")
    attn_mhc_blocks, ffn_mhc_blocks = _build_mhc_blocks(layer_weights, n_layers, H, n_hc)
    print(f"  attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights (cuda:0) and per-GPU RoPE tables ====
    torch.cuda.set_device(0)
    embed_w = global_weights.get("model.embed_tokens.weight")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
    # Fall back to the embedding matrix when the checkpoint has no lm_head.
    lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
    final_norm_w = global_weights.get("model.norm.weight")
    rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}

    # ==== KV caches: one single-slot cache per layer, allocated on that layer's GPU ====
    kv_caches = {
        li: SimpleKVCache(n_layers=1, head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
        for li in range(n_layers)
    }

    # ==== Phase 2: Compile ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    from dsv4.kernels.attention.production import dsv4_attention
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    try:
        _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print("  FMHA: compiled OK")
    except Exception as e:
        # Best-effort warm-up: report the failure and let inference surface it.
        print(f"  FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_loaded:.1f}s")

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}' → {input_ids.tolist()}")

    generated = input_ids[0].tolist()  # prompt ids, then generated ids
    if not generated:
        raise ValueError("PROMPT encoded to zero tokens; nothing to condition on")

    def forward_token(tid_val, pos):
        # Embed one token, run it through every layer at absolute position
        # `pos`, and return the (1, n_hc, H) streams on cuda:0.
        tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
        X = embed(tid).unsqueeze(1).expand(-1, n_hc, -1).clone()  # (1, n_hc, H)
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            target_device = torch.device(f"cuda:{gpu}")
            if X.device != target_device:
                X = X.to(target_device)
            torch.cuda.set_device(gpu)
            rc, rs = rope_caches[gpu]
            X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                              attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                              kv_caches[li], tid, pos)
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        return X

    # ==== Prefill: every prompt token except the last ====
    # The first decode step feeds the last prompt token, so each position
    # enters every KV cache exactly once.
    prefill_ids = generated[:-1]
    print(f"Prefilling {len(prefill_ids)} prompt tokens...")
    for pos, tid_val in enumerate(prefill_ids):
        forward_token(tid_val, pos)
    print(f"  Prefill done ({len(prefill_ids)} tokens, {time.time()-t_compiled:.1f}s)")

    # ==== Decode: generate new tokens ====
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        # Feed the newest token at its absolute position.
        X = forward_token(generated[-1], len(generated) - 1)

        # Read out stream 0 → RMSNorm → lm_head.
        # NOTE(review): only stream 0 is read; confirm whether the reference
        # combines the n_hc streams (e.g. with a dedicated head) instead.
        x_out = X[:, 0, :]
        if final_norm_w is not None:
            xf = x_out.float()
            rms = xf.pow(2).mean(-1, keepdim=True).add(final_norm_eps).rsqrt()
            x_out = (xf * rms * final_norm_w.float()).bfloat16()

        logits = torch.nn.functional.linear(x_out, lm_w)
        next_id = torch.argmax(logits, dim=-1).item()
        generated.append(next_id)

        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        has_nan = torch.isnan(logits.float()).any().item()
        lmin, lmax = logits.float().min().item(), logits.float().max().item()
        print(f"  Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan}")

        if has_nan:
            print("  NaN — stopping")
            break
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(generated, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input:  '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()
|