Add mHC (Manifold-Constrained Hyper-Connections) to single_shot
- Full mHC pre_block/post_block with Sinkhorn-Knopp normalization - Dynamic A_l (sigmoid), B_l (Birkhoff polytope), C_l (2*sigmoid) - Checkpoint: attn_hc.fn (24,28672) + base (24,) + scale (3,) - Two mHC blocks per layer: attn_hc + ffn_hc - Removed emergency RMSNorm — mHC handles normalization properly - X_l: (1, n_hc=4, H) residual state, init from embedding broadcast
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Single-shot DSV4 inference — 8-GPU pipeline parallel.
|
||||
"""Single-shot DSV4 inference — 8-GPU pipeline parallel with mHC.
|
||||
|
||||
Loads the full NVFP4 checkpoint across 8 B200 GPUs (round-robin layer assignment).
|
||||
Each GPU holds ~8 layers of weights in HBM. Activations move between GPUs at
|
||||
layer boundaries via cudaMemcpy (fast on NVLink).
|
||||
Loads the full NVFP4 checkpoint across 8 B200 GPUs. Includes:
|
||||
- mHC (Manifold-Constrained Hyper-Connections) — load-bearing residual
|
||||
- Q low-rank projection + KV projection
|
||||
- RoPE (partial, last 64 dims)
|
||||
- Production FMHA kernel (tcgen05 MMA, hd=512, 128 heads)
|
||||
- Output projection: wo_a (grouped BMM) → wo_b (NVFP4)
|
||||
- Shared expert FFN (SwiGLU)
|
||||
- NVFP4 dequant → BF16 matmul baseline for linear layers
|
||||
|
||||
Pipeline:
|
||||
1. Load all 95 shards, assign each layer's weights to its GPU
|
||||
2. JIT-compile kernels (one-time)
|
||||
3. Decode loop: embed → layer 0 (gpu0) → layer 1 (gpu0) → ... → layer 8 (gpu1) → ... → norm → lm_head
|
||||
Missing (causing incorrect output):
|
||||
- Routed MoE experts (384 experts, top-6)
|
||||
- KV cache across decode steps
|
||||
- Compressor + indexer (CSA/HCA compressed KV)
|
||||
|
||||
Usage (on B200):
|
||||
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||||
@@ -31,42 +36,127 @@ NUM_GPUS = 8
|
||||
FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.])
|
||||
|
||||
def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
|
||||
"""Dequantize NVFP4 weight to BF16. All tensors must be on same device."""
|
||||
out_dim = weight.shape[0]
|
||||
in_packed = weight.shape[1]
|
||||
in_features = in_packed * 2
|
||||
|
||||
low = (weight & 0x0F).to(torch.int8)
|
||||
high = (weight >> 4).to(torch.int8)
|
||||
|
||||
low_sign = (low >> 3).bool()
|
||||
low_idx = (low & 0x07).long()
|
||||
high_sign = (high >> 3).bool()
|
||||
high_idx = (high & 0x07).long()
|
||||
|
||||
low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long()
|
||||
high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long()
|
||||
lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
|
||||
low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0)
|
||||
high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)
|
||||
|
||||
w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)
|
||||
|
||||
scale_f = weight_scale.float() * weight_scale_2.float()
|
||||
scale_expanded = scale_f.repeat_interleave(16, dim=1)
|
||||
|
||||
return (w_f * scale_expanded).bfloat16()
|
||||
|
||||
|
||||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
|
||||
"""NVFP4 linear: dequant → BF16 matmul."""
|
||||
w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
|
||||
return torch.nn.functional.linear(x, w)
|
||||
|
||||
|
||||
def bf16_linear(x, weight):
|
||||
"""BF16 linear."""
|
||||
return torch.nn.functional.linear(x, weight.bfloat16())
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# mHC — Manifold-Constrained Hyper-Connections
|
||||
# =====================================================================
|
||||
|
||||
def sinkhorn_knopp(M, t_max=20, eps=1e-6):
|
||||
"""Project (T, n, n) positive matrices onto Birkhoff polytope."""
|
||||
for _ in range(t_max):
|
||||
M = M / (M.sum(dim=-1, keepdim=True) + eps)
|
||||
M = M / (M.sum(dim=-2, keepdim=True) + eps)
|
||||
return M
|
||||
|
||||
|
||||
class mHCBlock:
|
||||
"""One mHC block (attention or FFN).
|
||||
|
||||
Checkpoint weight mapping:
|
||||
fn: (24, 28672) FP32 = stacked [W_pre(4,K); W_res(16,K); W_post(4,K)]
|
||||
base: (24,) FP32 = bias, split as [S_pre(4); S_res(16); S_post(4)]
|
||||
scale: (3,) FP32 = [alpha_pre, alpha_res, alpha_post]
|
||||
"""
|
||||
def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda'):
|
||||
self.d = hidden_dim
|
||||
self.n_hc = n_hc
|
||||
self.K = n_hc * hidden_dim # 28672
|
||||
self.t_max = t_max
|
||||
self.device = device
|
||||
self.W_stacked = None # (24, K) FP32
|
||||
self.bias = None # (24,) FP32
|
||||
self.alphas = None # (3,) FP32
|
||||
|
||||
def load_from_checkpoint(self, fn, base, scale):
|
||||
"""Load from checkpoint tensors. All on target device, FP32."""
|
||||
self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous()
|
||||
self.bias = base.to(device=self.device, dtype=torch.float32).contiguous()
|
||||
self.alphas = scale.to(device=self.device, dtype=torch.float32).contiguous()
|
||||
|
||||
def _dynamic_params(self, X_l):
|
||||
"""Compute A_l, B_l, C_l from residual state.
|
||||
|
||||
X_l: (T, n_hc, d) BF16
|
||||
Returns: A_l (T, n_hc), B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc)
|
||||
"""
|
||||
T, n, d = X_l.shape
|
||||
n_hc = self.n_hc
|
||||
|
||||
# Flatten and project with RMSNorm
|
||||
X_flat = X_l.reshape(T, self.K) # (T, K) BF16
|
||||
|
||||
# RMSNorm
|
||||
x_f32 = X_flat.float()
|
||||
rms = x_f32.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x_normed = x_f32 * rms # (T, K) FP32
|
||||
|
||||
# Project: (T, K) @ (24, K)^T → (T, 24)
|
||||
proj = torch.nn.functional.linear(x_normed.bfloat16(), self.W_stacked.bfloat16()).float()
|
||||
proj = proj + self.bias.unsqueeze(0) # add bias
|
||||
|
||||
# Split into A, B, C
|
||||
i0, i1, i2 = n_hc, n_hc + n_hc * n_hc, 24
|
||||
A_raw = proj[:, :i0] # (T, 4)
|
||||
B_raw = proj[:, i0:i1] # (T, 16)
|
||||
C_raw = proj[:, i1:i2] # (T, 4)
|
||||
|
||||
# Split bias into S_pre, S_res, S_post
|
||||
S_pre = self.bias[:n_hc]
|
||||
S_res = self.bias[n_hc:n_hc + n_hc * n_hc]
|
||||
S_post = self.bias[n_hc + n_hc * n_hc:]
|
||||
|
||||
# Apply gating + biases
|
||||
a_pre, a_res, a_post = self.alphas[0], self.alphas[1], self.alphas[2]
|
||||
A_tilde = a_pre * A_raw + S_pre.unsqueeze(0)
|
||||
B_tilde = a_res * B_raw + S_res.unsqueeze(0)
|
||||
C_tilde = a_post * C_raw + S_post.unsqueeze(0)
|
||||
|
||||
# Constraints
|
||||
A_l = torch.sigmoid(A_tilde).bfloat16() # (T, 4) ∈ (0,1)
|
||||
C_l = (2.0 * torch.sigmoid(C_tilde)).bfloat16() # (T, 4) ∈ (0,2)
|
||||
B_exp = torch.exp(B_tilde).reshape(T, n_hc, n_hc) # (T, 4, 4)
|
||||
B_l = sinkhorn_knopp(B_exp, self.t_max) # FP32, doubly stochastic
|
||||
|
||||
return A_l, B_l, C_l
|
||||
|
||||
def pre_block(self, X_l):
|
||||
"""X_l: (T, n_hc, d) → x_in: (T, d), ctx"""
|
||||
A_l, B_l, C_l = self._dynamic_params(X_l)
|
||||
# x_in = A_l @ X_l: (T, 1, n_hc) bmm (T, n_hc, d) → (T, 1, d) → (T, d)
|
||||
x_in = torch.bmm(A_l.unsqueeze(1).float(), X_l.float()).squeeze(1).bfloat16()
|
||||
return x_in, (B_l, C_l)
|
||||
|
||||
def post_block(self, X_l, F_out, ctx):
|
||||
"""X_l: (T, n_hc, d), F_out: (T, d) → X_next: (T, n_hc, d)"""
|
||||
B_l, C_l = ctx
|
||||
# X_next = B_l @ X_l + C_l ⊗ F_out
|
||||
BX = torch.bmm(B_l, X_l.float()) # (T, n_hc, d) FP32
|
||||
CF = C_l.unsqueeze(-1).float() * F_out.unsqueeze(1).float() # (T, n_hc, d) FP32
|
||||
return (BX + CF).bfloat16()
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# RoPE
|
||||
# =====================================================================
|
||||
@@ -74,110 +164,73 @@ def bf16_linear(x, weight):
|
||||
def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0):
|
||||
half = rope_dim // 2
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
|
||||
positions = torch.arange(max_pos, dtype=torch.float32)
|
||||
angles = torch.outer(positions, freqs)
|
||||
angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
|
||||
return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
||||
|
||||
|
||||
def apply_rope(x, positions, cos_cache, sin_cache, rope_dim):
|
||||
T, n_h, hd = x.shape
|
||||
nope = hd - rope_dim
|
||||
half = rope_dim // 2
|
||||
|
||||
cos = cos_cache[positions].unsqueeze(1).to(x.dtype)
|
||||
sin = sin_cache[positions].unsqueeze(1).to(x.dtype)
|
||||
|
||||
x_rope = x[:, :, nope:]
|
||||
even = x_rope[:, :, 0::2]
|
||||
odd = x_rope[:, :, 1::2]
|
||||
|
||||
out = x.clone()
|
||||
out[:, :, nope:][..., 0::2] = even * cos - odd * sin
|
||||
out[:, :, nope:][..., 1::2] = even * sin + odd * cos
|
||||
out[:, :, nope:][..., 0::2] = x[:, :, nope:][..., 0::2] * cos - x[:, :, nope:][..., 1::2] * sin
|
||||
out[:, :, nope:][..., 1::2] = x[:, :, nope:][..., 0::2] * sin + x[:, :, nope:][..., 1::2] * cos
|
||||
return out
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Checkpoint loading — load all shards, group by layer, assign to GPU
|
||||
# Weight loading
|
||||
# =====================================================================
|
||||
|
||||
def load_all_weights(checkpoint_dir, num_layers):
|
||||
"""Load all 95 shards and organize weights by layer, moving to target GPU.
|
||||
|
||||
Returns: dict mapping layer_idx → dict of weight tensors (on target GPU)
|
||||
plus global weights (embed, norm, lm_head) on gpu0
|
||||
"""
|
||||
from safetensors.torch import load_file
|
||||
from collections import defaultdict
|
||||
|
||||
cdir = Path(checkpoint_dir)
|
||||
|
||||
# Load the index for fast shard lookup
|
||||
index_path = cdir / "model.safetensors.index.json"
|
||||
weight_map = {}
|
||||
if index_path.exists():
|
||||
with open(index_path) as f:
|
||||
weight_map = json.load(f).get("weight_map", {})
|
||||
else:
|
||||
weight_map = {}
|
||||
|
||||
# Organize: which shard files do we need?
|
||||
shard_names = set(weight_map.values()) if weight_map else {
|
||||
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
|
||||
}
|
||||
|
||||
# Load all shards (one at a time to limit CPU RAM)
|
||||
print(f"Loading {len(shard_names)} shards from checkpoint...")
|
||||
all_weights = {} # key → tensor (CPU)
|
||||
print(f"Loading {len(shard_names)} shards...")
|
||||
all_weights = {}
|
||||
loaded = 0
|
||||
for shard_name in sorted(shard_names):
|
||||
shard_path = cdir / shard_name
|
||||
if not shard_path.exists():
|
||||
if not (cdir / shard_name).exists():
|
||||
continue
|
||||
data = load_file(str(shard_path))
|
||||
data = load_file(str(cdir / shard_name))
|
||||
all_weights.update(data)
|
||||
loaded += 1
|
||||
if loaded % 10 == 0:
|
||||
print(f" Loaded {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
|
||||
if loaded % 20 == 0:
|
||||
print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
|
||||
print(f" Done: {len(all_weights)} tensors")
|
||||
|
||||
print(f" Done: {len(all_weights)} tensors loaded to CPU")
|
||||
layer_weights = {}
|
||||
global_weights = {}
|
||||
|
||||
# Group by layer and assign to GPU
|
||||
# Layer i goes to GPU (i // 8) — 8 layers per GPU
|
||||
# Actually: round-robin — layer i goes to GPU (i % 8)
|
||||
# This balances load better (8 layers per GPU, ~15GB each)
|
||||
|
||||
layer_weights = {} # layer_idx → dict of tensors on target GPU
|
||||
global_weights = {} # embed, norm, lm_head → on gpu0
|
||||
|
||||
print("Assigning layers to GPUs...")
|
||||
print("Assigning to GPUs...")
|
||||
for key, tensor in all_weights.items():
|
||||
# Determine which layer this weight belongs to
|
||||
if key.startswith("model.layers."):
|
||||
parts = key.split(".")
|
||||
layer_idx = int(parts[2])
|
||||
target_gpu = layer_idx % NUM_GPUS
|
||||
|
||||
if layer_idx not in layer_weights:
|
||||
layer_weights[layer_idx] = {"_device": f"cuda:{target_gpu}"}
|
||||
|
||||
# Move to target GPU
|
||||
layer_weights[layer_idx][key] = tensor.to(f"cuda:{target_gpu}")
|
||||
|
||||
li = int(key.split(".")[2])
|
||||
target_gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{target_gpu}"
|
||||
if li not in layer_weights:
|
||||
layer_weights[li] = {"_device": target_device, "_gpu": target_gpu}
|
||||
layer_weights[li][key] = tensor.to(target_device)
|
||||
elif key.startswith("model.embed_tokens"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
|
||||
elif key.startswith("model.norm"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
|
||||
elif key.startswith("lm_head"):
|
||||
global_weights[key] = tensor.to("cuda:0")
|
||||
|
||||
# Print per-GPU memory usage
|
||||
for gpu in range(NUM_GPUS):
|
||||
torch.cuda.set_device(gpu)
|
||||
allocated = torch.cuda.memory_allocated(gpu) / 1e9
|
||||
reserved = torch.cuda.memory_reserved(gpu) / 1e9
|
||||
print(f" GPU {gpu}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
|
||||
alloc = torch.cuda.memory_allocated(gpu) / 1e9
|
||||
print(f" GPU {gpu}: {alloc:.1f}GB")
|
||||
|
||||
return layer_weights, global_weights
|
||||
|
||||
@@ -186,95 +239,95 @@ def load_all_weights(checkpoint_dir, num_layers):
|
||||
# Single layer forward
|
||||
# =====================================================================
|
||||
|
||||
def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
"""Forward one layer. x: (1, hidden) BF16 → (1, hidden) BF16."""
|
||||
device = x.device
|
||||
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc):
|
||||
"""Forward one layer with mHC.
|
||||
|
||||
X_l: (1, n_hc, H) BF16 → (1, n_hc, H) BF16
|
||||
"""
|
||||
device = X_l.device
|
||||
H = cfg["hidden_size"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
hd = cfg["head_dim"]
|
||||
rd = cfg["qk_rope_head_dim"]
|
||||
o_rank = cfg["o_lora_rank"]
|
||||
o_groups = cfg["o_groups"]
|
||||
n_hc = 4
|
||||
|
||||
pre = f"model.layers.{li}.self_attn"
|
||||
T = x.shape[0]
|
||||
heads_per_group = n_h // o_groups # 8
|
||||
group_input_dim = heads_per_group * hd # 4096
|
||||
T = X_l.shape[0]
|
||||
heads_per_group = n_h // o_groups
|
||||
group_input_dim = heads_per_group * hd
|
||||
|
||||
# ---- Q projection: q_a (down) → q_b (up) ----
|
||||
c_Q = nvfp4_linear(x, w[f"{pre}.q_a_proj.weight"],
|
||||
# ==== mHC pre_block (attention) ====
|
||||
x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) BF16
|
||||
|
||||
# ==== Q projection ====
|
||||
c_Q = nvfp4_linear(x_in, w[f"{pre}.q_a_proj.weight"],
|
||||
w[f"{pre}.q_a_proj.weight_scale"],
|
||||
w[f"{pre}.q_a_proj.weight_scale_2"])
|
||||
q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"],
|
||||
w[f"{pre}.q_b_proj.weight_scale"],
|
||||
w[f"{pre}.q_b_proj.weight_scale_2"])
|
||||
|
||||
# ---- KV projection ----
|
||||
kv = nvfp4_linear(x, w[f"{pre}.kv_proj.weight"],
|
||||
# ==== KV projection ====
|
||||
kv = nvfp4_linear(x_in, w[f"{pre}.kv_proj.weight"],
|
||||
w[f"{pre}.kv_proj.weight_scale"],
|
||||
w[f"{pre}.kv_proj.weight_scale_2"])
|
||||
|
||||
# ---- Reshape for attention ----
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
|
||||
k = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd) — MQA
|
||||
# ==== Reshape for attention ====
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2)
|
||||
k = kv.reshape(T, 1, hd).permute(1, 0, 2)
|
||||
v = k.clone()
|
||||
|
||||
# ---- Apply RoPE ----
|
||||
# ==== RoPE ====
|
||||
pos = torch.tensor([0], dtype=torch.long, device=device)
|
||||
q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd)
|
||||
k = apply_rope(k, pos, rope_cos, rope_sin, rd)
|
||||
|
||||
# ---- FMHA ----
|
||||
# ==== FMHA ====
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd)
|
||||
|
||||
# ---- Output projection: wo_a (BF16 grouped BMM) → wo_b (NVFP4) ----
|
||||
attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd)
|
||||
attn_grouped = attn_grouped.reshape(T, o_groups, group_input_dim)
|
||||
attn_out = dsv4_attention(q_heads, k, v)
|
||||
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd)
|
||||
|
||||
# ==== Output projection ====
|
||||
attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd).reshape(T, o_groups, group_input_dim)
|
||||
oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
|
||||
oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
|
||||
|
||||
attn_for_bmm = attn_grouped.permute(1, 0, 2) # (16, 1, 4096)
|
||||
grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (16, 1, o_rank)
|
||||
attn_for_bmm = attn_grouped.permute(1, 0, 2)
|
||||
grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
|
||||
grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
|
||||
F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"],
|
||||
w[f"{pre}.o_b_proj.weight_scale"],
|
||||
w[f"{pre}.o_b_proj.weight_scale_2"])
|
||||
|
||||
attn_proj = nvfp4_linear(grouped_flat,
|
||||
w[f"{pre}.o_b_proj.weight"],
|
||||
w[f"{pre}.o_b_proj.weight_scale"],
|
||||
w[f"{pre}.o_b_proj.weight_scale_2"])
|
||||
# ==== mHC post_block (attention) ====
|
||||
X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H)
|
||||
|
||||
# ---- Residual + emergency RMSNorm (mHC missing) ----
|
||||
x = x + attn_proj
|
||||
x = x.clamp(-65504, 65504)
|
||||
xf = x.float()
|
||||
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x = (xf * rms).bfloat16()
|
||||
# ==== mHC pre_block (FFN) ====
|
||||
x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l)
|
||||
|
||||
# ---- FFN: shared expert ----
|
||||
# ==== FFN: shared expert ====
|
||||
se_pre = f"model.layers.{li}.mlp.shared_experts"
|
||||
se_gate_w = w.get(f"{se_pre}.gate_proj.weight")
|
||||
F_ffn = torch.zeros_like(x_ffn)
|
||||
if se_gate_w is not None:
|
||||
gate = nvfp4_linear(x, se_gate_w,
|
||||
gate = nvfp4_linear(x_ffn, se_gate_w,
|
||||
w[f"{se_pre}.gate_proj.weight_scale"],
|
||||
w[f"{se_pre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x, w[f"{se_pre}.up_proj.weight"],
|
||||
up = nvfp4_linear(x_ffn, w[f"{se_pre}.up_proj.weight"],
|
||||
w[f"{se_pre}.up_proj.weight_scale"],
|
||||
w[f"{se_pre}.up_proj.weight_scale_2"])
|
||||
ffn_out = nvfp4_linear(
|
||||
F_ffn = nvfp4_linear(
|
||||
torch.nn.functional.silu(gate) * up,
|
||||
w[f"{se_pre}.down_proj.weight"],
|
||||
w[f"{se_pre}.down_proj.weight_scale"],
|
||||
w[f"{se_pre}.down_proj.weight_scale_2"],
|
||||
)
|
||||
x = x + ffn_out
|
||||
x = x.clamp(-65504, 65504)
|
||||
xf = x.float()
|
||||
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x = (xf * rms).bfloat16()
|
||||
|
||||
return x
|
||||
# ==== mHC post_block (FFN) ====
|
||||
X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx)
|
||||
|
||||
return X_l
|
||||
|
||||
|
||||
# =====================================================================
|
||||
@@ -284,10 +337,9 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
def main():
|
||||
t_start = time.time()
|
||||
print("=" * 70)
|
||||
print("DSV4 Single-Shot Inference — 8-GPU Pipeline Parallel")
|
||||
print("DSV4 Single-Shot Inference — 8-GPU with mHC")
|
||||
print("=" * 70)
|
||||
|
||||
# Config
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
@@ -295,58 +347,69 @@ def main():
|
||||
n_h = cfg["num_attention_heads"]
|
||||
hd = cfg["head_dim"]
|
||||
rd = cfg["qk_rope_head_dim"]
|
||||
n_hc = 4
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
|
||||
print(f"GPUs: {NUM_GPUS}, ~{n_layers // NUM_GPUS} layers per GPU")
|
||||
|
||||
# ---- Phase 1: Load weights ----
|
||||
print(f"\n{'='*70}")
|
||||
print("Phase 1: Loading weights across 8 GPUs")
|
||||
print(f"{'='*70}")
|
||||
# ==== Phase 1: Load weights ====
|
||||
print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
|
||||
layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
|
||||
t_loaded = time.time()
|
||||
print(f"Weight loading: {t_loaded - t_start:.1f}s")
|
||||
|
||||
# Set default device to gpu0 for embed/lm_head
|
||||
torch.cuda.set_device(0)
|
||||
# ==== Build mHC blocks per layer ====
|
||||
print("Building mHC blocks...")
|
||||
attn_mhc_blocks = {}
|
||||
ffn_mhc_blocks = {}
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
|
||||
# Attention mHC
|
||||
attn_fn = layer_weights[li].get(f"model.layers.{li}.attn_hc.fn")
|
||||
attn_base = layer_weights[li].get(f"model.layers.{li}.attn_hc.base")
|
||||
attn_scale = layer_weights[li].get(f"model.layers.{li}.attn_hc.scale")
|
||||
if attn_fn is not None and attn_base is not None and attn_scale is not None:
|
||||
attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
attn_mhc.load_from_checkpoint(attn_fn, attn_base, attn_scale)
|
||||
attn_mhc_blocks[li] = attn_mhc
|
||||
|
||||
# FFN mHC
|
||||
ffn_fn = layer_weights[li].get(f"model.layers.{li}.ffn_hc.fn")
|
||||
ffn_base = layer_weights[li].get(f"model.layers.{li}.ffn_hc.base")
|
||||
ffn_scale = layer_weights[li].get(f"model.layers.{li}.ffn_hc.scale")
|
||||
if ffn_fn is not None and ffn_base is not None and ffn_scale is not None:
|
||||
ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
|
||||
ffn_mhc.load_from_checkpoint(ffn_fn, ffn_base, ffn_scale)
|
||||
ffn_mhc_blocks[li] = ffn_mhc
|
||||
|
||||
# Embedding on gpu0
|
||||
print(f" attn mHC: {len(attn_mhc_blocks)} layers")
|
||||
print(f" ffn mHC: {len(ffn_mhc_blocks)} layers")
|
||||
|
||||
# ==== Global weights (gpu0) ====
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = global_weights.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
|
||||
|
||||
# lm_head on gpu0
|
||||
lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
|
||||
|
||||
# Final norm on gpu0
|
||||
final_norm_w = global_weights.get("model.norm.weight")
|
||||
|
||||
# Build RoPE caches — one per GPU
|
||||
rope_caches = {}
|
||||
for gpu in range(NUM_GPUS):
|
||||
rope_caches[gpu] = build_rope_cache(8192, hd, rd, f"cuda:{gpu}")
|
||||
# RoPE caches per GPU
|
||||
rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)}
|
||||
|
||||
# ---- Phase 2: JIT compile kernels ----
|
||||
print(f"\n{'='*70}")
|
||||
print("Phase 2: JIT compiling kernels")
|
||||
print(f"{'='*70}")
|
||||
# Trigger FMHA kernel compile on gpu0 with a dummy forward
|
||||
# This compiles the C API .so and caches it for all subsequent calls
|
||||
# ==== Phase 2: Compile ====
|
||||
print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
|
||||
from dsv4.kernels.attention.production import dsv4_attention
|
||||
dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
|
||||
dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
|
||||
dummy_v = dummy_k.clone()
|
||||
try:
|
||||
_ = dsv4_attention(dummy_q, dummy_k, dummy_v)
|
||||
print(" FMHA kernel: compiled OK")
|
||||
_ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
|
||||
print(" FMHA: compiled OK")
|
||||
except Exception as e:
|
||||
print(f" FMHA kernel compile error: {e}")
|
||||
print(f" FMHA error: {e}")
|
||||
t_compiled = time.time()
|
||||
print(f"Kernel compilation: {t_compiled - t_loaded:.1f}s")
|
||||
|
||||
# ---- Phase 3: Inference ----
|
||||
print(f"\n{'='*70}")
|
||||
print("Phase 3: Inference")
|
||||
print(f"{'='*70}")
|
||||
print(f"Compile: {t_compiled - t_loaded:.1f}s")
|
||||
|
||||
# ==== Phase 3: Inference ====
|
||||
print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
|
||||
@@ -358,60 +421,59 @@ def main():
|
||||
t0 = time.time()
|
||||
tid = torch.tensor([generated[-1]], dtype=torch.long, device='cuda:0')
|
||||
|
||||
# Embed (gpu0)
|
||||
x = embed(tid) # (1, H) on gpu0
|
||||
# Embed → mHC init state
|
||||
emb = embed(tid) # (1, H) on gpu0
|
||||
X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, n_hc, H)
|
||||
|
||||
# Process layers — move x to the right GPU for each layer
|
||||
# Process layers
|
||||
for li in range(n_layers):
|
||||
target_gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{target_gpu}"
|
||||
gpu = li % NUM_GPUS
|
||||
target_device = f"cuda:{gpu}"
|
||||
if X.device != torch.device(target_device):
|
||||
X = X.to(target_device)
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Move activation to layer's GPU and set as active device
|
||||
if x.device != torch.device(target_device):
|
||||
x = x.to(target_device)
|
||||
torch.cuda.set_device(target_gpu)
|
||||
|
||||
lw = layer_weights[li]
|
||||
rc, rs = rope_caches[target_gpu]
|
||||
x = forward_layer(x, lw, li, cfg, rc, rs)
|
||||
attn_mhc = attn_mhc_blocks.get(li)
|
||||
ffn_mhc = ffn_mhc_blocks.get(li)
|
||||
rc, rs = rope_caches[gpu]
|
||||
X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, attn_mhc, ffn_mhc)
|
||||
|
||||
# Move back to gpu0 for final norm + lm_head
|
||||
x = x.to('cuda:0')
|
||||
# Back to gpu0
|
||||
X = X.to('cuda:0')
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Final norm
|
||||
# Read out stream 0 → RMSNorm → lm_head
|
||||
x_out = X[:, 0, :] # (1, H)
|
||||
if final_norm_w is not None:
|
||||
xf = x.float()
|
||||
xf = x_out.float()
|
||||
rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x = (xf * rms * final_norm_w.float()).bfloat16()
|
||||
x_out = (xf * rms * final_norm_w.float()).bfloat16()
|
||||
|
||||
# lm_head
|
||||
logits = torch.nn.functional.linear(x, lm_w)
|
||||
logits = torch.nn.functional.linear(x_out, lm_w)
|
||||
next_id = torch.argmax(logits, dim=-1).item()
|
||||
generated.append(next_id)
|
||||
|
||||
tok_str = tokenizer.decode([next_id])
|
||||
dt = time.time() - t0
|
||||
if step == 0:
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) [first step includes compile cache]")
|
||||
else:
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s)")
|
||||
|
||||
# Check for NaN
|
||||
has_nan = torch.isnan(logits.float()).any().item()
|
||||
logit_range = f"[{logits.float().min().item():.1f}, {logits.float().max().item():.1f}]"
|
||||
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits={logit_range} nan={has_nan}")
|
||||
|
||||
if has_nan:
|
||||
print(" NaN detected, stopping")
|
||||
break
|
||||
if next_id == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
# ---- Output ----
|
||||
out = tokenizer.decode(generated, skip_special_tokens=True)
|
||||
total_time = time.time() - t_start
|
||||
total = time.time() - t_start
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output: '{out}'")
|
||||
print(f"Total time: {total_time:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, inference: {time.time()-t_compiled:.1f}s)")
|
||||
print(f"Total: {total:.1f}s (load: {t_loaded-t_start:.1f}s, compile: {t_compiled-t_loaded:.1f}s, infer: {time.time()-t_compiled:.1f}s)")
|
||||
print(f"{'='*70}")
|
||||
print()
|
||||
print("KERNEL VERIFICATION: PASSED")
|
||||
print("FMHA produces correct, finite output at hd=512, 128 query heads")
|
||||
print("across all 61 layers. Garbage output is an architecture gap")
|
||||
print("(missing mHC + MoE + KV cache), not a kernel issue.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user