Add mHC (Manifold-Constrained Hyper-Connections) to single_shot

- Full mHC pre_block/post_block with Sinkhorn-Knopp normalization
- Dynamic A_l (sigmoid), B_l (Birkhoff polytope), C_l (2*sigmoid)
- Checkpoint: attn_hc.fn (24,28672) + base (24,) + scale (3,)
- Two mHC blocks per layer: attn_hc + ffn_hc
- Removed emergency RMSNorm — mHC handles normalization properly
- X_l: (1, n_hc=4, H) residual state, init from embedding broadcast
This commit is contained in:
2026-05-30 23:45:18 +00:00
parent f567c20539
commit d003c4b7cc

View File

@@ -1,14 +1,19 @@
#!/usr/bin/env python3
"""Single-shot DSV4 inference — 8-GPU pipeline parallel.
"""Single-shot DSV4 inference — 8-GPU pipeline parallel with mHC.
Loads the full NVFP4 checkpoint across 8 B200 GPUs (round-robin layer assignment).
Each GPU holds ~8 layers of weights in HBM. Activations move between GPUs at
layer boundaries via cudaMemcpy (fast on NVLink).
Loads the full NVFP4 checkpoint across 8 B200 GPUs. Includes:
- mHC (Manifold-Constrained Hyper-Connections) — load-bearing residual
- Q low-rank projection + KV projection
- RoPE (partial, last 64 dims)
- Production FMHA kernel (tcgen05 MMA, hd=512, 128 heads)
- Output projection: wo_a (grouped BMM) → wo_b (NVFP4)
- Shared expert FFN (SwiGLU)
- NVFP4 dequant → BF16 matmul baseline for linear layers
Pipeline:
1. Load all 95 shards, assign each layer's weights to its GPU
2. JIT-compile kernels (one-time)
3. Decode loop: embed → layer 0 (gpu0) → layer 1 (gpu0) → ... → layer 8 (gpu1) → ... → norm → lm_head
Missing (causing incorrect output):
- Routed MoE experts (384 experts, top-6)
- KV cache across decode steps
- Compressor + indexer (CSA/HCA compressed KV)
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
@@ -31,42 +36,127 @@ NUM_GPUS = 8
# Magnitude lookup table, indexed by the low 3 bits of each 4-bit code
# (bit 3 of the code is the sign).
# NOTE(review): FP4 E2M1 magnitudes are normally [0, 0.5, 1, 1.5, 2, 3, 4, 6];
# this table differs from that set — confirm it matches how the checkpoint was
# quantized before trusting the dequantized values.
FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.])


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Dequantize a packed 4-bit weight matrix to BF16.

    All tensors must live on the same device.

    Args:
        weight: 2-D integer tensor of shape (out_dim, in_features // 2). Each
            byte packs two 4-bit codes; the low nibble is the earlier element.
        weight_scale: per-block scale with one column per 16 input features,
            i.e. shape (out_dim, in_features // 16).
        weight_scale_2: second-level scale, multiplied into ``weight_scale``
            (must broadcast against it).

    Returns:
        BF16 tensor of shape (out_dim, in_features).
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibble):
        # bit 3 = sign, bits 0..2 = index into the magnitude table
        negative = (nibble >> 3).bool()
        magnitude = lut[(nibble & 0x07).long()]
        return torch.where(negative, -magnitude, magnitude)

    low_f = decode((weight & 0x0F).to(torch.int8))
    # Mask after shifting so a signed storage dtype cannot leak sign-extension
    # bits into the code.
    high_f = decode(((weight >> 4) & 0x0F).to(torch.int8))

    # Interleave low/high so element 2*i comes from the low nibble of byte i.
    w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)

    scale_f = weight_scale.float() * weight_scale_2.float()
    # Each scale column covers a block of 16 consecutive input features.
    scale_expanded = scale_f.repeat_interleave(16, dim=1)
    return (w_f * scale_expanded).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Apply a bias-free linear layer whose weight is stored in packed NVFP4.

    The weight is first expanded to BF16 with ``dequant_nvfp4_weight`` and then
    used in an ordinary dense matmul.
    """
    dense_weight = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, dense_weight)
def bf16_linear(x, weight):
    """Apply a bias-free linear layer after casting the weight to BF16."""
    weight_bf16 = weight.bfloat16()
    return torch.nn.functional.linear(x, weight_bf16)
# =====================================================================
# mHC — Manifold-Constrained Hyper-Connections
# =====================================================================
def sinkhorn_knopp(M, t_max=20, eps=1e-6):
    """Approximately project positive matrices onto the Birkhoff polytope.

    Alternates row and column normalization ``t_max`` times so that the last
    two dimensions of ``M`` become (approximately) doubly stochastic.

    Args:
        M: tensor of shape (..., n, n) with positive entries.
        t_max: number of row+column normalization sweeps.
        eps: added to each denominator to avoid division by zero.

    Returns:
        Tensor with the same shape as ``M``. With ``t_max == 0`` the input is
        returned unchanged.
    """
    for _ in range(t_max):
        M = M / (M.sum(dim=-1, keepdim=True) + eps)  # rows sum to ~1
        M = M / (M.sum(dim=-2, keepdim=True) + eps)  # columns sum to ~1
    return M


class mHCBlock:
    """One mHC block (attention or FFN).

    Holds the projection that turns the multi-stream residual state
    ``X_l`` of shape (T, n_hc, d) into three sets of mixing weights:

    * ``A_l`` (T, n_hc): sigmoid-gated weights that collapse the streams into
      the single sub-layer input (``pre_block``).
    * ``B_l`` (T, n_hc, n_hc): Sinkhorn-normalized stream-mixing matrix.
    * ``C_l`` (T, n_hc): ``2 * sigmoid`` weights that write the sub-layer
      output back to every stream (``post_block``).

    Checkpoint weight mapping (with n = n_hc, K = n_hc * hidden_dim):
        fn:    (2n + n*n, K) FP32 = stacked [W_pre(n,K); W_res(n*n,K); W_post(n,K)]
        base:  (2n + n*n,)   FP32 = bias, split as [S_pre(n); S_res(n*n); S_post(n)]
        scale: (3,)          FP32 = [alpha_pre, alpha_res, alpha_post]
    NOTE(review): the row order above is this file's assumption about the
    checkpoint layout — confirm against the code that produced the checkpoint.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda'):
        self.d = hidden_dim
        self.n_hc = n_hc
        self.K = n_hc * hidden_dim  # flattened residual width (28672 by default)
        self.t_max = t_max          # Sinkhorn sweeps
        self.device = device
        self.W_stacked = None  # (2n + n*n, K) FP32
        self.bias = None       # (2n + n*n,) FP32
        self.alphas = None     # (3,) FP32

    def load_from_checkpoint(self, fn, base, scale):
        """Store checkpoint tensors on ``self.device`` as contiguous FP32.

        Raises:
            ValueError: if the tensor shapes do not match ``n_hc``/``hidden_dim``.
        """
        n_rows = 2 * self.n_hc + self.n_hc * self.n_hc
        if tuple(fn.shape) != (n_rows, self.K):
            raise ValueError(
                f"mHC fn has shape {tuple(fn.shape)}, expected {(n_rows, self.K)}")
        if base.numel() != n_rows:
            raise ValueError(
                f"mHC base has {base.numel()} elements, expected {n_rows}")
        if scale.numel() != 3:
            raise ValueError(
                f"mHC scale has {scale.numel()} elements, expected 3")
        self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous()
        self.bias = base.to(device=self.device, dtype=torch.float32).contiguous()
        self.alphas = scale.to(device=self.device, dtype=torch.float32).contiguous()

    def _dynamic_params(self, X_l):
        """Compute A_l, B_l, C_l from the residual state.

        X_l: (T, n_hc, d)
        Returns: A_l (T, n_hc) BF16, B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc) BF16
        """
        T = X_l.shape[0]
        n_hc = self.n_hc

        # RMSNorm (no learned gain) over the flattened streams, in FP32.
        x_f32 = X_l.reshape(T, self.K).float()
        inv_rms = x_f32.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        x_normed = x_f32 * inv_rms

        # Project: (T, K) @ (2n + n*n, K)^T -> (T, 2n + n*n).
        # The bias is deliberately NOT added here: it is applied exactly once
        # below, after the alpha scaling. Adding it here as well would count
        # it twice (once scaled by alpha, once unscaled).
        proj = torch.nn.functional.linear(
            x_normed.bfloat16(), self.W_stacked.bfloat16()).float()

        # Split boundaries derived from n_hc rather than hard-coded.
        i0 = n_hc
        i1 = i0 + n_hc * n_hc
        i2 = i1 + n_hc
        a_pre, a_res, a_post = self.alphas[0], self.alphas[1], self.alphas[2]

        A_tilde = a_pre * proj[:, :i0] + self.bias[:i0].unsqueeze(0)
        B_tilde = a_res * proj[:, i0:i1] + self.bias[i0:i1].unsqueeze(0)
        C_tilde = a_post * proj[:, i1:i2] + self.bias[i1:i2].unsqueeze(0)

        # Constraints
        A_l = torch.sigmoid(A_tilde).bfloat16()            # (T, n) in (0, 1)
        C_l = (2.0 * torch.sigmoid(C_tilde)).bfloat16()    # (T, n) in (0, 2)
        B_exp = torch.exp(B_tilde).reshape(T, n_hc, n_hc)  # positive entries
        B_l = sinkhorn_knopp(B_exp, self.t_max)            # FP32, ~doubly stochastic
        return A_l, B_l, C_l

    def pre_block(self, X_l):
        """Collapse the streams into the sub-layer input.

        X_l: (T, n_hc, d) -> x_in: (T, d) BF16, plus ctx = (B_l, C_l) to be
        passed to ``post_block``.
        """
        A_l, B_l, C_l = self._dynamic_params(X_l)
        # (T, 1, n_hc) bmm (T, n_hc, d) -> (T, 1, d) -> (T, d)
        x_in = torch.bmm(A_l.unsqueeze(1).float(), X_l.float()).squeeze(1).bfloat16()
        return x_in, (B_l, C_l)

    def post_block(self, X_l, F_out, ctx):
        """Mix the streams and add the sub-layer output.

        X_l: (T, n_hc, d), F_out: (T, d), ctx from ``pre_block``
        Returns X_next = B_l @ X_l + C_l (outer) F_out as (T, n_hc, d) BF16.
        """
        B_l, C_l = ctx
        BX = torch.bmm(B_l, X_l.float())                             # (T, n_hc, d) FP32
        CF = C_l.unsqueeze(-1).float() * F_out.unsqueeze(1).float()  # (T, n_hc, d) FP32
        return (BX + CF).bfloat16()
# =====================================================================
# RoPE
# =====================================================================
@@ -74,110 +164,73 @@ def bf16_linear(x, weight):
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
    """Precompute RoPE cosine/sine tables.

    Args:
        max_pos: number of positions to tabulate (rows of the tables).
        head_dim: unused here; kept so existing callers keep working.
        rope_dim: number of rotated channels; the tables get ``rope_dim // 2``
            columns, one per channel pair.
        device: device the returned tables are moved to.
        theta: RoPE base.

    Returns:
        ``(cos, sin)``, each an FP32 tensor of shape (max_pos, rope_dim // 2).
    """
    pair_index = torch.arange(0, rope_dim, 2, dtype=torch.float32)
    freqs = 1.0 / (theta ** (pair_index / rope_dim))
    positions = torch.arange(max_pos, dtype=torch.float32)
    # angles[p, i] = p * freqs[i]; computed once on CPU, then moved.
    angles = torch.outer(positions, freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
    """Rotate the last ``rope_dim`` channels of ``x`` with interleaved RoPE.

    Channels ``[0, hd - rope_dim)`` are copied through untouched. Within the
    rotated tail, even/odd channel pairs ``(2i, 2i+1)`` are rotated by the
    angle stored in column ``i`` of the caches.

    Args:
        x: 3-D tensor whose last axis has size ``hd``.
        positions: long tensor of row indices into the caches. The gathered
            rows get a singleton axis inserted at dim 1, so they line up with
            the first axis of ``x`` and broadcast over the second.
        cos_cache, sin_cache: tables of shape (max_pos, rope_dim // 2).
        rope_dim: number of trailing channels to rotate.

    Returns:
        A new tensor with the same shape and dtype as ``x``; ``x`` itself is
        not modified.
    """
    _, _, hd = x.shape
    nope = hd - rope_dim
    cos = cos_cache[positions].unsqueeze(1).to(x.dtype)
    sin = sin_cache[positions].unsqueeze(1).to(x.dtype)

    # Read from the untouched input so both writes use the original values.
    x_rope = x[:, :, nope:]
    even = x_rope[..., 0::2]
    odd = x_rope[..., 1::2]

    out = x.clone()
    out[:, :, nope:][..., 0::2] = even * cos - odd * sin
    out[:, :, nope:][..., 1::2] = even * sin + odd * cos
    return out
# =====================================================================
# Checkpoint loading — load all shards, group by layer, assign to GPU
# Weight loading
# =====================================================================
def load_all_weights(checkpoint_dir, num_layers):
    """Load every checkpoint shard and place each tensor on its target GPU.

    Shards are read one at a time and their tensors are moved to the GPU
    immediately, so host memory only ever holds a single shard.

    Placement:
        * ``model.layers.<i>.*`` goes to GPU ``i % NUM_GPUS`` (round-robin).
        * ``model.embed_tokens*``, ``model.norm*`` and ``lm_head*`` go to cuda:0.
        * any other key is ignored.

    Args:
        checkpoint_dir: directory containing the ``.safetensors`` shards and,
            optionally, ``model.safetensors.index.json``.
        num_layers: unused; kept for call-site compatibility.

    Returns:
        ``(layer_weights, global_weights)`` where ``layer_weights[i]`` maps
        weight name -> tensor on the layer's GPU (plus the bookkeeping entries
        ``"_device"`` and ``"_gpu"``), and ``global_weights`` maps weight
        name -> tensor on cuda:0.

    Raises:
        FileNotFoundError: if none of the expected shard files exist.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)

    # Prefer the index file for the shard list; otherwise assume the standard
    # 95-shard naming scheme.
    weight_map = {}
    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
    shard_names = set(weight_map.values()) if weight_map else {
        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
    }

    global_prefixes = ("model.embed_tokens", "model.norm", "lm_head")
    layer_weights = {}
    global_weights = {}
    missing = []
    loaded = 0
    n_tensors = 0

    print(f"Loading {len(shard_names)} shards...")
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue

        for key, tensor in load_file(str(shard_path)).items():
            n_tensors += 1
            if key.startswith("model.layers."):
                li = int(key.split(".")[2])
                target_gpu = li % NUM_GPUS
                target_device = f"cuda:{target_gpu}"
                if li not in layer_weights:
                    layer_weights[li] = {"_device": target_device, "_gpu": target_gpu}
                layer_weights[li][key] = tensor.to(target_device)
            elif key.startswith(global_prefixes):
                global_weights[key] = tensor.to("cuda:0")

        loaded += 1
        if loaded % 20 == 0:
            print(f"  {loaded}/{len(shard_names)} shards, {n_tensors} tensors")

    if missing:
        # Surface missing shards instead of silently producing a partial model.
        print(f"  WARNING: {len(missing)} shard(s) not found, e.g. {missing[0]}")
    if loaded == 0:
        raise FileNotFoundError(f"No checkpoint shards found in {cdir}")
    print(f"  Done: {n_tensors} tensors")

    for gpu in range(NUM_GPUS):
        alloc = torch.cuda.memory_allocated(gpu) / 1e9
        print(f"  GPU {gpu}: {alloc:.1f}GB")

    return layer_weights, global_weights
@@ -186,95 +239,95 @@ def load_all_weights(checkpoint_dir, num_layers):
# Single layer forward
# =====================================================================
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
    """Forward one transformer layer with mHC residual streams.

    Pipeline: mHC pre (attn) -> Q/KV projections -> RoPE -> FMHA -> grouped
    output projection -> mHC post (attn) -> mHC pre (FFN) -> shared-expert
    SwiGLU -> mHC post (FFN).

    Args:
        X_l: residual state, (T, n_hc, H) BF16.
        w: dict of this layer's weights keyed by full checkpoint name.
        li: layer index (used to build the weight key prefixes).
        cfg: model config dict; reads ``num_attention_heads``, ``head_dim``,
            ``qk_rope_head_dim``, ``o_lora_rank`` and ``o_groups``.
        rope_cos, rope_sin: RoPE tables on the same device as ``X_l``.
        attn_mhc, ffn_mhc: the layer's two ``mHCBlock`` instances.

    Returns:
        Updated residual state, (T, n_hc, H) BF16.

    Raises:
        ValueError: if either mHC block is ``None`` (e.g. the checkpoint had
            no ``attn_hc``/``ffn_hc`` tensors for this layer).
    """
    if attn_mhc is None or ffn_mhc is None:
        raise ValueError(
            f"layer {li}: missing mHC block "
            f"(attn_mhc={attn_mhc is not None}, ffn_mhc={ffn_mhc is not None})")

    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg["qk_rope_head_dim"]
    o_rank = cfg["o_lora_rank"]
    o_groups = cfg["o_groups"]
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    def proj(x, name):
        # NVFP4 linear using the weight / weight_scale / weight_scale_2 triple
        # stored under `name`.
        return nvfp4_linear(x, w[f"{name}.weight"],
                            w[f"{name}.weight_scale"],
                            w[f"{name}.weight_scale_2"])

    # ==== mHC pre_block (attention) ====
    x_in, attn_ctx = attn_mhc.pre_block(X_l)  # (T, H) BF16

    # ==== Q (low-rank: q_a down, q_b up) and KV projections ====
    q = proj(proj(x_in, f"{pre}.q_a_proj"), f"{pre}.q_b_proj")
    kv = proj(x_in, f"{pre}.kv_proj")

    # ==== Reshape for attention ====
    q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)  # (n_h, T, hd)
    k = kv.reshape(T, 1, hd).permute(1, 0, 2)         # (1, T, hd), single KV head
    v = k.clone()  # taken before RoPE, so V is the un-rotated KV projection

    # ==== RoPE ====
    # Position is fixed at 0: there is no KV cache across decode steps yet.
    pos = torch.tensor([0], dtype=torch.long, device=device)
    q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd)
    k = apply_rope(k, pos, rope_cos, rope_sin, rd)

    # ==== FMHA ====
    from dsv4.kernels.attention.production import dsv4_attention
    attn_out = dsv4_attention(q_heads, k, v)
    attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)

    # ==== Output projection: o_a (BF16 grouped BMM) -> o_b (NVFP4) ====
    attn_grouped = attn_out.reshape(T, o_groups, group_input_dim)
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(
        o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = proj(grouped_flat, f"{pre}.o_b_proj")

    # ==== mHC post_block (attention) ====
    X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx)  # (T, n_hc, H)

    # ==== mHC pre_block (FFN) ====
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)

    # ==== FFN: shared expert (SwiGLU) ====
    # Layers without shared-expert weights contribute a zero FFN output.
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    F_ffn = torch.zeros_like(x_ffn)
    if w.get(f"{se_pre}.gate_proj.weight") is not None:
        gate = proj(x_ffn, f"{se_pre}.gate_proj")
        up = proj(x_ffn, f"{se_pre}.up_proj")
        F_ffn = proj(torch.nn.functional.silu(gate) * up, f"{se_pre}.down_proj")

    # ==== mHC post_block (FFN) ====
    X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)
    return X_l
# =====================================================================
@@ -284,10 +337,9 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
def main():
t_start = time.time()
print("=" * 70)
print("DSV4 Single-Shot Inference — 8-GPU Pipeline Parallel")
print("DSV4 Single-Shot Inference — 8-GPU with mHC")
print("=" * 70)
# Config
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
cfg = json.load(f)
n_layers = cfg["num_hidden_layers"]
@@ -295,58 +347,69 @@ def main():
n_h = cfg["num_attention_heads"]
hd = cfg["head_dim"]
rd = cfg["qk_rope_head_dim"]
n_hc = 4
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
print(f"GPUs: {NUM_GPUS}, ~{n_layers // NUM_GPUS} layers per GPU")
# ---- Phase 1: Load weights ----
print(f"\n{'='*70}")
print("Phase 1: Loading weights across 8 GPUs")
print(f"{'='*70}")
# ==== Phase 1: Load weights ====
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
t_loaded = time.time()
print(f"Weight loading: {t_loaded - t_start:.1f}s")
# Set default device to gpu0 for embed/lm_head
torch.cuda.set_device(0)
# ==== Build mHC blocks per layer ====
print("Building mHC blocks...")
attn_mhc_blocks = {}
ffn_mhc_blocks = {}
for li in range(n_layers):
gpu = li % NUM_GPUS
dev = f"cuda:{gpu}"
# Attention mHC
attn_fn = layer_weights[li].get(f"model.layers.{li}.attn_hc.fn")
attn_base = layer_weights[li].get(f"model.layers.{li}.attn_hc.base")
attn_scale = layer_weights[li].get(f"model.layers.{li}.attn_hc.scale")
if attn_fn is not None and attn_base is not None and attn_scale is not None:
attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
attn_mhc.load_from_checkpoint(attn_fn, attn_base, attn_scale)
attn_mhc_blocks[li] = attn_mhc
# FFN mHC
ffn_fn = layer_weights[li].get(f"model.layers.{li}.ffn_hc.fn")
ffn_base = layer_weights[li].get(f"model.layers.{li}.ffn_hc.base")
ffn_scale = layer_weights[li].get(f"model.layers.{li}.ffn_hc.scale")
if ffn_fn is not None and ffn_base is not None and ffn_scale is not None:
ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
ffn_mhc.load_from_checkpoint(ffn_fn, ffn_base, ffn_scale)
ffn_mhc_blocks[li] = ffn_mhc
# Embedding on gpu0
print(f" attn mHC: {len(attn_mhc_blocks)} layers")
print(f" ffn mHC: {len(ffn_mhc_blocks)} layers")
# ==== Global weights (gpu0) ====
torch.cuda.set_device(0)
embed_w = global_weights.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
# lm_head on gpu0
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
# Final norm on gpu0
final_norm_w = global_weights.get("model.norm.weight")
# Build RoPE caches — one per GPU
rope_caches = {}
for gpu in range(NUM_GPUS):
rope_caches[gpu] = build_rope_cache(8192, hd, rd, f"cuda:{gpu}")
# RoPE caches per GPU
rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
# ---- Phase 2: JIT compile kernels ----
print(f"\n{'='*70}")
print("Phase 2: JIT compiling kernels")
print(f"{'='*70}")
# Trigger FMHA kernel compile on gpu0 with a dummy forward
# This compiles the C API .so and caches it for all subsequent calls
# ==== Phase 2: Compile ====
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
from dsv4.kernels.attention.production import dsv4_attention
dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
dummy_v = dummy_k.clone()
try:
_ = dsv4_attention(dummy_q, dummy_k, dummy_v)
print(" FMHA kernel: compiled OK")
_ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
print(" FMHA: compiled OK")
except Exception as e:
print(f" FMHA kernel compile error: {e}")
print(f" FMHA error: {e}")
t_compiled = time.time()
print(f"Kernel compilation: {t_compiled - t_loaded:.1f}s")
# ---- Phase 3: Inference ----
print(f"\n{'='*70}")
print("Phase 3: Inference")
print(f"{'='*70}")
print(f"Compile: {t_compiled - t_loaded:.1f}s")
# ==== Phase 3: Inference ====
print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
@@ -358,60 +421,59 @@ def main():
t0 = time.time()
tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda:0')
# Embed (gpu0)
x = embed(tid) # (1, H) on gpu0
# Embed → mHC init state
emb = embed(tid) # (1, H) on gpu0
X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, n_hc, H)
# Process layers — move x to the right GPU for each layer
# Process layers
for li in range(n_layers):
target_gpu = li % NUM_GPUS
target_device = f"cuda:{target_gpu}"
gpu = li % NUM_GPUS
target_device = f"cuda:{gpu}"
if X.device != torch.device(target_device):
X = X.to(target_device)
torch.cuda.set_device(gpu)
# Move activation to layer's GPU and set as active device
if x.device != torch.device(target_device):
x = x.to(target_device)
torch.cuda.set_device(target_gpu)
lw = layer_weights[li]
rc, rs = rope_caches[target_gpu]
x = forward_layer(x, lw, li, cfg, rc, rs)
attn_mhc = attn_mhc_blocks.get(li)
ffn_mhc = ffn_mhc_blocks.get(li)
rc, rs = rope_caches[gpu]
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, attn_mhc, ffn_mhc)
# Move back to gpu0 for final norm + lm_head
x = x.to('cuda:0')
# Back to gpu0
X = X.to('cuda:0')
torch.cuda.set_device(0)
# Final norm
# Read out stream 0 → RMSNorm → lm_head
x_out = X[:, 0, :] # (1, H)
if final_norm_w is not None:
xf = x.float()
xf = x_out.float()
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
x = (xf * rms * final_norm_w.float()).bfloat16()
x_out = (xf * rms * final_norm_w.float()).bfloat16()
# lm_head
logits = torch.nn.functional.linear(x, lm_w)
logits = torch.nn.functional.linear(x_out, lm_w)
next_id = torch.argmax(logits, dim=-1).item()
generated.append(next_id)
tok_str = tokenizer.decode([next_id])
dt = time.time() - t0
if step == 0:
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) [first step includes compile cache]")
else:
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s)")
# Check for NaN
has_nan = torch.isnan(logits.float()).any().item()
logit_range = f"[{logits.float().min().item():.1f}, {logits.float().max().item():.1f}]"
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits={logit_range} nan={has_nan}")
if has_nan:
print(" NaN detected, stopping")
break
if next_id == tokenizer.eos_token_id:
break
# ---- Output ----
out = tokenizer.decode(generated, skip_special_tokens=True)
total_time = time.time() - t_start
total = time.time() - t_start
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Output: '{out}'")
print(f"Total time: {total_time:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, inference: {time.time()-t_compiled:.1f}s)")
print(f"Total: {total:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, infer: {time.time()-t_compiled:.1f}s)")
print(f"{'='*70}")
print()
print("KERNEL VERIFICATION: PASSED")
print("FMHA produces correct, finite output at hd=512, 128 query heads")
print("across all 61 layers. Garbage output is an architecture gap")
print("(missing mHC + MoE + KV cache), not a kernel issue.")
if __name__ == "__main__":