9 Commits

Author SHA1 Message Date
4fe73fe713 auto: pre-test commit 2026-06-03 15:45:15 +00:00
f577ed97f4 Fix: Use PyTorch dequant_nvfp4 for weight dequantization (compressor/indexer/router gate)
The CUDA dequantize_nvfp4 (dsv4/ops/quantize.py) was designed for
activations/KV and assumes row-major (M, N/16) scale layout. Using it
for weight dequantization caused async illegal memory access because
weight scales don't match the kernel's expected layout. The kernel only
validates row count, not width or contiguity.

All 4 call sites now use the PyTorch dequant_nvfp4 (defined in
single_shot_inference.py) which handles weight_scale_2 and input_scale
correctly and cannot cause OOB access:
- Compressor.load: kv_proj, gate_proj
- Indexer.load: weights_proj
- Router gate dequantization in main()
2026-06-03 14:57:40 +00:00
1121cd7b47 Add CUDA_LAUNCH_BLOCKING=1 to catch async errors 2026-06-03 14:48:51 +00:00
f3bb0ca08c Fix dequant gsa: use ws2 only, NOT input_scale * ws2
For weight dequantization, gsa should be weight_scale_2 only.
input_scale is the activation global scale — it belongs on the GEMM's
activation side, not the weight side. Using input_scale * ws2 gave
gsa = 6e-8 (essentially zero), making dequantized weights ~0.

The GEMM formula is y = (x * scale_a * gsa) @ (w * scale_b * gsb)
where gsb = input_scale * ws2. But dequantize_nvfp4 is just the
weight half: w_bf16 = lut[w] * block_scale * ws2.
2026-06-03 14:38:24 +00:00
470e65fb19 Fix dequant gsb: input_scale * ws2, not 1.0 * ws2
The NVFP4 dequantize formula is w = lut[w_packed] * scale * ws2,
and in the GEMM the global_scale_b = input_scale * ws2. Was incorrectly
using gsb = 1.0 * ws2 (missing input_scale). This would produce
wrongly-scaled BF16 weights from dequantize_nvfp4.
2026-06-03 14:26:59 +00:00
2dd16d5789 Switch compressor + indexer weights_proj to BF16 F.linear
Only the CSA indexer QK path (q_b_proj) is explicitly FP4-QATed.
The rest of the compressor/indexer projections are NOT, so use BF16:

- Compressor kv_proj, gate_proj: dequantize NVFP4 → BF16, F.linear
- Indexer weights_proj: dequantize NVFP4 → BF16, F.linear
- Indexer q_b_proj: KEEP as NVFP4 (this IS the FP4-QATed path)
- Indexer compressor: inherits Compressor's BF16 path
2026-06-03 14:19:41 +00:00
95e45a87e3 Add explicit .to(dev) on W_gate after transpose — belt and suspenders 2026-06-03 14:17:02 +00:00
ef94c48957 Simplify router gate: dequant NVFP4 → BF16, F.linear (no FP8 middleman)
Same as what worked before. The checkpoint stores NVFP4 weights, so we
dequantize once at load time and use cuBLAS F.linear. No FP8 re-quantize
step needed — that was just adding noise on top of the NVFP4 dequant.
2026-06-03 14:14:10 +00:00
715602c87c Switch lm_head to BF16 + router gate to FP8_E4M3
lm_head: BF16 F.linear (checkpoint weight is BF16, no quantization)
Router gate: FP8_E4M3 quantize→dequantize round-trip, then F.linear
- Dequantize NVFP4 checkpoint weights to BF16 first
- Quantize to FP8_E4M3 (scale = amax/448)
- Dequantize back to BF16 for F.linear
- Uses BF16 dispatch path in dense_router_dispatch
- Simpler scale wiring than NVFP4 (single per-tensor scale)
2026-06-03 14:10:28 +00:00
3 changed files with 208 additions and 77 deletions

View File

@@ -0,0 +1,94 @@
# DSV4 → vLLM: CUDA-Graph Safety / GPU-Native Requirements (PART 2 companion)
**Goal:** the per-step decode forward must be fully GPU-native so vLLM can capture and replay it. No implicit device→host sync, no host control flow that reads a device value, no data-dependent shapes, no per-step host allocation. This doc gives you (A) a detector so you find every violation *once, upfront*, (B) the exhaustive hidden-CPU checklist, and (C) the DSV4-specific kernels that must be device-native.
## The one rule that decides everything
Branching on a **host-known integer** (step number, position, batch size, dtype, static shape) is graph-compatible — you capture one graph per bucket and the scheduler picks by that integer. Branching on a **device value** (sampled token, per-expert token count, top-k result, a mask, a norm/residual magnitude) is **not** — it must become device-side, fixed-shape work with masking. Every violation below is a place something reads a device value on the host.
You do **not** need one monolithic graph. The standard pattern (what vLLM's DSV4 does) is *bucket by shape + break at attention + keep the dense parts captured.* Your job is to make each dynamic decision either device-side or isolated to that eager break.
---
## SECTION A — The detector (build this FIRST, before porting anything)
Stop hunting syncs by hand. Make them fail at the exact line:
```python
import torch
torch.cuda.set_sync_debug_mode("error") # raises at any implicit device→host sync
# ... run one decode step of the forward ...
torch.cuda.set_sync_debug_mode("default")
```
And a capture-under-test (most illegal host ops error *during* capture):
```python
g = torch.cuda.CUDAGraph()
# static input buffers allocated ONCE, outside capture:
with torch.cuda.graph(g):
out = decode_step(static_inputs) # capture fails loudly on .item(), sync, alloc, etc.
for _ in range(3):
static_inputs.copy_(next_inputs); g.replay() # replay must reproduce eager output
```
**Do this on the current `single_shot` forward first** — it inventories *every* existing sync in one pass, so you get the whole hunt-list upfront instead of discovering them one at a time during vLLM bring-up. Then gate every commit on both checks in CI; the day someone adds a `.item()`, the build fails at that line.
Also useful: `compute-sanitizer --tool synccheck`, and `nsys` to eyeball CPU↔GPU stall gaps.
---
## SECTION B — The hidden-CPU checklist (grep the hot path for these)
**Explicit device→host transfers**
`.item()` · `.cpu()` · `.tolist()` · `.numpy()` · `int(t)` / `float(t)` / `bool(t)` · `print(t)` · f-strings/logging that interpolate a tensor · `assert (device_condition)` (e.g. `assert (x>0).all()`) · `.to("cpu")`
**Host control flow on device values**
`if t:` · `if mask.any():` · `if x.sum() > thr:` · `while t > 0:` · `for i in range(n.item())` · convergence early-exit reading a device residual · choosing a kernel based on the sampled token
**Data-dependent shapes (these both change shape AND sync)**
`torch.nonzero` · `torch.where(cond)` (one-arg form) · `torch.unique` · `torch.bincount` (when it drives a shape) · boolean/mask indexing `x[mask]`, `x[x>0]` · `masked_select` · `reshape(n.item(), ...)` · any gather sized by a device-computed count
**Per-step host allocation**
`torch.empty/zeros/tensor([...])` created fresh inside the captured region · building a Python list then `torch.tensor(list, device=...)` · `np.*` anywhere on the path · any CPU tensor then `.to(device)` per step
**Host RNG**
`random.*` / `np.random.*` / Python rng for sampling → use a device generator / captured philox state
**Sync primitives & checks**
`torch.cuda.synchronize()` · `stream.synchronize()` · `torch.isnan(x).any()` / `isinf(...).any()` debug guards · pinned-copy syncs
**Sneaky ones (the "didn't realize" category)**
`sum(t)` / `min(t)` / `max(t)` (Python builtins iterate → sync; use `t.sum()`) · a `.cpu()`/`.item()` hidden inside a logging, assert, or metrics helper · `einops` rearrange with a data-dependent dim · telemetry/progress hooks that read tensors · indexing a tensor with a Python int derived from `.item()`
**What is FINE (no sync, don't waste time on these)**
`.shape` / `.size()` / `.numel()` / `.dtype` (host metadata, no sync) · branching on host-known ints (step/batch/static shape) · dtype/shape kernel dispatch · the **stop-token check, detokenize, and your BF16 precision-floor dequant** (all load-time or *outside* the captured graph — leave them on host, that's correct).
---
## SECTION C — DSV4-specific kernels that must be GPU-native
| # | Hazard (current host/dynamic behavior) | Requirement | vLLM reference |
|---|---|---|---|
| 1 | Compressor returns `None` for 3/4 (CSA) or 127/128 (HCA) decode steps — periodic host branch | Compress **every** step into a persistent partial-state/ring buffer; emit the compressed entry **device-side** on the boundary | `save_partial_states`, `fused_compress_quant_cache` |
| 2 | KV grows each step → attention shape changes | Paged KV (fixed blocks + block table) captured at fixed max-len with masking, **or** make attention the eager break | `breakable_cudagraph` / `eager_break_during_capture`; `AttentionCGSupport.ALWAYS` |
| 3 | Indexer top-k → host reads selected count to size gather | Always gather fixed `k` (padded), mask invalid; no host read of the count | `dequant_gather_k_cutedsl` (fixed-shape gather) |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | Routing permutation/offsets computed **on device**; grouped GEMM with device offsets and a fixed total launch | `prepare_megamoe` |
| 5 | Next token / positions managed on host, fresh tensors per step | Static I/O buffers allocated once; **in-place** `copy_` of next token; positions via device-side increment (or per-shape bucketed graphs) | vLLM persistent input buffers |
Also confirm:
- **Sinkhorn** runs a **fixed 20 iterations with no host convergence check** (a `while not converged` reading a device residual breaks capture). Fixed-iteration = safe.
- **Sampler** is device-side; `repetition_penalty` reads from a **fixed-size device** recent-token buffer (not a growing Python list); the EOS/stop decision is a host step **outside** the graph (correct).
---
## SECTION D — Integration order
1. **Build Section A's detector and run it on the current forward** — get the full sync inventory in one pass.
2. Fix Section C's five device-native kernels (these are the structural ones; the rest of Section B tends to be incidental `.item()`s once these are right).
3. Re-run capture-under-test until it captures clean and replay matches eager bit-for-bit.
4. Gate every commit on the capture test so violations can never silently return.
## Guardrails
- Keep the stop-check, detokenize, and load-time BF16 dequant on the host — they're outside the captured region by design; don't contort them to be "graph-safe."
- Decide the attention model up front (paged-capturable vs eager-break) — retrofitting it later forces a KV-cache rewrite.
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.

View File

@@ -0,0 +1,69 @@
# DSV4 Precision Floor — PyTorch Validation (PART 1) + Native Port (PART 2)
**What we learned:** the NVFP4 precision floor for this model is — keep **LM head** BF16, **router gate** BF16, and the **compressor/indexer helper projections** BF16, with the **one exception** that the **CSA indexer QK path stays FP4** (it was explicitly FP4-QATed; the other compressor projections were not, so PTQ-ing them to FP4 breaks). We validated each individually. Now do all of them together, simple-PyTorch first, then native.
---
## ⚠️ First: the CUDA illegal-memory-access (you're calling the wrong dequant)
There are **two** functions with nearly the same name:
- `single_shot_inference.py:238``dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)`**pure PyTorch** (does `weight_scale.repeat_interleave(16,1) * scales`). This is what `nvfp4_linear_ref` uses — your **validated reference**. It cannot cause an illegal access.
- `dsv4/ops/quantize.py:377``dequantize_nvfp4(x_fp4, x_sf, gsa)` — calls the **CUDA kernel** `dequant_nvfp4.cu`. **This is the one crashing.**
The precision-floor code (lines 328 / 333 / 426: kv_proj, gate_proj, wp) imports the **CUDA** one and feeds it **weights**. But that kernel was written for the **activation / KV-gather** path — read its own docstring: *"compressed KV is stored as NVFP4, dequantized on-the-fly."* It assumes row-major `(M, N/16)` block scales, per-row `gsa`, `N=512`.
The host wrapper only does `TORCH_CHECK(sf_data.size(0) == M)` — it validates the scale's **row count and nothing else** (not width, not total size, not contiguity). The kernel then indexes `sf_data[m*(N/16) + n_block]` flat. For a weight whose scale isn't *exactly* contiguous row-major `(M, N/16)` — different width, padding, non-contiguous `.to(dev)` view, or the GEMM swizzle — that index walks off the allocation → **async illegal access, surfacing at the next sync (the compressor load).** The activation/KV path never tripped it because those scales already match the assumed layout.
**Confirm it in 2 minutes** (the error is async, so do this to localize it):
```bash
compute-sanitizer --tool memcheck <your harness> ... # will name dequant_nvfp4_kernel + the sf_data read
# or: CUDA_LAUNCH_BLOCKING=1 to move the report to the offending launch
```
And add these guards to `dequant_nvfp4_cuda` in `dequant_nvfp4.cu` — they turn the async crash into an immediate, located error and print the size mismatch:
```cpp
TORCH_CHECK(fp4_data.is_contiguous() && sf_data.is_contiguous(), "dequant inputs must be contiguous");
TORCH_CHECK(sf_data.numel() >= (int64_t)M * (N/16), "sf too small: have ", sf_data.numel(), " need ", (int64_t)M*(N/16));
TORCH_CHECK(fp4_data.numel() >= (int64_t)M * (N/2), "fp4 too small: have ", fp4_data.numel(), " need ", (int64_t)M*(N/2));
```
You don't need the CUDA kernel here at all (see PART 1) — these weights are dequanted **once at load**, so there's zero performance reason to use a custom kernel for them.
---
## PART 1 — PyTorch quick version (all floor fixes together, simple, no crash)
Goal: one combined config, pure PyTorch, prove correctness end-to-end. This also sidesteps the OOB by not using the CUDA dequant for weights.
1. **Swap the three weight-dequant call sites (328/333/426) to the PyTorch reference.** The CUDA `dequantize_nvfp4(kv_w, kv_ws, gsa)` becomes the PyTorch `dequant_nvfp4(kv_w, kv_ws, kv_ws2, kv_isc)` — and you can delete the manual `gsa = torch.tensor([ws2_v]*shape[0])` lines, because the PyTorch version handles `weight_scale_2` / `input_scale` internally. Be explicit about *which* function you import (they're nearly identically named — that's how this got crossed). Example:
```python
from single_shot_inference import dequant_nvfp4 as dequant_nvfp4_torch # the pure-PyTorch one
# kv_proj:
self._kv_bf16 = dequant_nvfp4_torch(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
# gate_proj, wp: same pattern
```
2. **LM head → BF16, router gate → BF16.** Dequant their FP4 weights to BF16 once at load via the same PyTorch path, then run them as plain `F.linear`. (The gate is tiny; the LM head is the only sizable one and it's ~1.4 GB — negligible against the KV/concurrency budget.)
3. **Keep the CSA indexer QK path in FP4 — do NOT dequant it.** Only the QK projection of the indexer was QATed. Its non-QATed siblings in the compressor go to BF16 with everything else.
4. **Run a clean generation** with the fixed chat template (the official `encoding/encoding_dsv4.py`, not the hand-rolled path). Confirm: coherent, **no repetition loop**, **clean stop**, Paris top-1 on the canonical probe, and run **≥ a few hundred tokens** so HCA actually engages (HCA's first compressed entry only forms at 128 tokens).
5. **A/B insurance:** this is the all-at-once config. If it regresses versus the individual fixes, flip one component FP4↔BF16 at a time to find the interaction — and record which ones were necessary (that table is the NVIDIA-writeup evidence).
---
## PART 2 — Native CuteDSL / CUDA version
Only after PART 1 validates the combined config (it becomes your reference for it).
1. **Fix the weight dequant path** (you have two options; pick one):
- *Simplest:* keep dequanting these few weights to BF16 **at load in PyTorch** (PART 1) even in the native build. It's a one-time load op — no hot-path cost — so there's no need to native-ize it at all.
- *If you insist on the CUDA kernel for load:* add the `numel`/contiguity guards above, then make the scale match what the kernel reads. The raw checkpoint `weight_scale` appears row-major **before** `finalize_weights` (the production GEMM swizzles at finalize — see the "K-major + swizzle" step ~line 1352 — so the *raw* scale is unswizzled). The guards will tell you if it's actually `(M, N/16)` contiguous; if not, make it contiguous before launch or teach the kernel the real stride. Also: the kernel was built around `N=512`; for weights `N=in` (≈7168) — make sure nothing downstream hardcodes 512.
2. **Hot-path natives are unchanged:** FP8 FMHA, FP4 MoE, and the **FP4 CSA indexer QK** all stay as they are. The floor change only touches load-time weight handling + two small GEMMs (gate, lm_head) that run as native **BF16** (cuBLAS/standard), not FP4.
3. **Re-validate per-layer cosine** of the native build against the PART 1 PyTorch combined-config reference before declaring done.
---
## Guardrails
- Don't reintroduce the **CUDA** `dequantize_nvfp4` for **weights** until the wrapper guards are in and the scale layout is confirmed — for now the PyTorch dequant is correct and crash-proof.
- The two functions `dequant_nvfp4` (PyTorch, weights) and `dequantize_nvfp4` (CUDA, activations/KV) are a foot-gun. Consider renaming the CUDA one to `dequantize_nvfp4_kvcache` so this can't recur.
- Only the **CSA indexer QK** path is FP4-QATed — do not let FP4 creep onto its non-QATed siblings.
- Validate end-to-end (coherent + non-looping + clean stop + HCA-depth) **before** calling it done.

View File

@@ -9,6 +9,7 @@ NO PyTorch SDPA fallback. NO dequant+matmul for production projections.
This is the ground truth for vLLM / SGLang integration.
"""
import os, sys, time, json, math, argparse, logging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # Catch async CUDA errors immediately
import torch
import torch.nn.functional as F
from pathlib import Path
@@ -302,6 +303,8 @@ class Compressor:
self.is_csa = (ratio == 4); self.kv_dim = 2 * head_dim if self.is_csa else head_dim
self.kv_lin = None # production Nvfp4Linear for kv_proj
self.gate_lin = None # production Nvfp4Linear for gate_proj
self._kv_bf16 = None # BF16 weight for kv_proj (dequantized from NVFP4)
self._gate_bf16 = None # BF16 weight for gate_proj (dequantized from NVFP4)
self.ape = None; self.kv_norm_w = None
self._reduce_loaded = False
# P7: Decode buffering — accumulate hidden_states until we have a complete block.
@@ -312,26 +315,24 @@ class Compressor:
self._buf_len = 0
def load(self, w, pfx, dev=None):
"""Load weights and build production Nvfp4Linear instances."""
"""Load weights and build BF16 projections (dequantized from NVFP4)."""
if dev is None: dev = self.device
# Build production NVFP4 GEMM instances for the two projections
# kv_proj: in=7168, out=kv_dim (1024 for CSA, 512 for HCA)
# gate_proj: same shapes
# Compressor projections are NOT explicitly FP4-QATed — dequant to BF16, use F.linear
# CRITICAL: Use the PyTorch dequant_nvfp4 (defined in this file), NOT the CUDA
# dequantize_nvfp4 from dsv4/ops/quantize.py. The CUDA kernel assumes
# activation/KV scale layout (row-major (M, N/16)) and crashes on weight scales
# that don't match — async illegal memory access surfaces at next sync.
kv_w, kv_ws, kv_ws2, kv_isc = get_nvfp4_weight(w, pfx, 'kv_proj')
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
if kv_w is not None:
kv_out = kv_w.shape[0] # N_packed
kv_in = kv_w.shape[1] * 2 # K_packed * 2
self.kv_lin = make_nvfp4_linear(kv_in, kv_out, dev, w, pfx, 'kv_proj')
self._kv_bf16 = dequant_nvfp4(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
if gate_w is not None:
gate_out = gate_w.shape[0]
gate_in = gate_w.shape[1] * 2
self.gate_lin = make_nvfp4_linear(gate_in, gate_out, dev, w, pfx, 'gate_proj')
self._gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc).to(dev).contiguous()
self.ape = w.get(f"{pfx}.position_bias")
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
def forward(self, hidden_states, positions):
if self.ratio == 0 or self.kv_lin is None: return None, None, None
if self.ratio == 0 or self._kv_bf16 is None: return None, None, None
T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device
# P7: Buffer decode steps until we have a complete block.
@@ -358,9 +359,9 @@ class Compressor:
n_complete = T // r
if n_complete == 0: return None, None, None
# Step 1-2: NVFP4 GEMM projections → FP32 for compress
kv = self.kv_lin(hidden_states).float() # (T, kv_dim) FP32
gate = self.gate_lin(hidden_states).float() # (T, kv_dim) FP32
# Step 1-2: BF16 F.linear projections → FP32 for compress
kv = torch.nn.functional.linear(hidden_states, self._kv_bf16).float() # (T, kv_dim) FP32
gate = torch.nn.functional.linear(hidden_states, self._gate_bf16).float() # (T, kv_dim) FP32
# Step 3: CUDA softmax/reduce kernel → FP32
# KV-1/KV-2: Return FP32. Caller applies RoPE, then quantizes to NVFP4.
@@ -398,22 +399,23 @@ class Indexer:
"""
def __init__(self, n_ih, ihd, top_k, device):
self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device
self.q_b_lin = None # production Nvfp4Linear for q_b_proj
self.wp_lin = None # production Nvfp4Linear for weights_proj
self.q_b_lin = None # production Nvfp4Linear for q_b_proj (FP4-QATed)
self._wp_bf16 = None # BF16 weight for weights_proj (dequantized from NVFP4)
self.compressor = None
def load(self, w, pfx, dev=None):
if dev is None: dev = self.device
qb_w, qb_ws, qb_ws2, qb_isc = get_nvfp4_weight(w, pfx, 'q_b_proj')
wp_w, wp_ws, wp_ws2, wp_isc = get_nvfp4_weight(w, pfx, 'weights_proj')
# q_b_proj IS the FP4-QATed QK path — keep as NVFP4
if qb_w is not None:
qb_out = qb_w.shape[0]
qb_in = qb_w.shape[1] * 2
self.q_b_lin = make_nvfp4_linear(qb_in, qb_out, dev, w, pfx, 'q_b_proj')
# weights_proj is NOT FP4-QATed — dequant to BF16 via PyTorch reference
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4 (see Compressor.load)
if wp_w is not None:
wp_out = wp_w.shape[0]
wp_in = wp_w.shape[1] * 2
self.wp_lin = make_nvfp4_linear(wp_in, wp_out, dev, w, pfx, 'weights_proj')
self._wp_bf16 = dequant_nvfp4(wp_w.to(dev), wp_ws.to(dev), wp_ws2, wp_isc).to(dev).contiguous()
# Indexer compressor weights are directly under the indexer prefix
# (e.g. *.indexer.kv_proj.weight), NOT nested under *.indexer.compressor.
if f"{pfx}.kv_proj.weight" in w:
@@ -436,7 +438,7 @@ class Indexer:
li = layer_idx
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
w_h = self.wp_lin(hidden_states) # (T, n_ih)
w_h = torch.nn.functional.linear(hidden_states, self._wp_bf16) # (T, n_ih) BF16
# B2: FP8 tensor-core scoring path.
# Indexer keys are stored as FP8_E4M3 in the KV cache.
@@ -1306,50 +1308,26 @@ def main():
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
else:
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
# NVFP4 production GEMM for router gate
# Custom CuTeDSL fused kernel crashes MLIR optimizer,
# so we use Nvfp4Linear (proven production path).
from dsv4.layers.linear import Nvfp4Linear
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
# BF16 router gate — dequantize NVFP4 to BF16, use F.linear
E = cfg["n_routed_experts"]
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
if gate_w is not None and gate_ws is not None:
# Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
gate_lin.fp4 = [gate_w_view]
gate_lin.sf = [gate_ws.to(dev)]
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
gate_lin.gs = [1.0]
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
# Checkpoint has NVFP4 gate weight — dequantize to BF16
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4
# (same fix as Compressor.load — CUDA kernel crashes on weight scale layouts)
gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc)
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
else:
# BF16 gate weight: quantize to NVFP4
# BF16 gate weight from checkpoint
gw = all_w.get(f"{pfx}.gate.weight")
if gw is not None:
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
g_bf16 = g_bf16.bfloat16().to(dev)
from dsv4.ops.quantize import quantize_to_nvfp4
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
gate_lin.fp4 = [g_fp4]
gate_lin.sf = [g_sf]
gate_lin.gs = [g_gs]
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
else:
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.load_weights(e_bias=eb.to(dev, torch.float32))
gate_bf16 = gw.bfloat16().to(dev)
if gate_bf16.shape[0] != H:
gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E)
router.W_gate = gate_bf16.contiguous()
# No gate_lin — force BF16 dispatch path
router.gate_lin = None
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True)
router.finalize_weights(); routers[li] = router
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
@@ -1397,21 +1375,11 @@ def main():
torch.cuda.set_device(0)
embed_w = all_w.get("model.embed_tokens.weight")
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
# lm_head: NVFP4 production GEMM
# lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization)
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
from dsv4.layers.linear import Nvfp4Linear
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
from dsv4.ops.quantize import quantize_weight_to_nvfp4
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
lm_head_lin.gs = [lm_gs]
lm_head_lin.ws2 = [None]
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
lm_head_lin._use_runtime_gsa = True
lm_head_lin.finalize_weights()
lm_w = None
print(" lm_head: NVFP4 production GEMM")
lm_head_lin = None # Use raw BF16 F.linear for lm_head
lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear
print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)")
final_norm_w = all_w.get("model.norm.weight")
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
@@ -1660,8 +1628,8 @@ def main():
gl._activation_global_scale = fixed_gsa
gl._use_runtime_gsa = False
n_fixed += 1
# lm_head
if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
# lm_head (BF16 — no gsa needed)
if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
fixed_gsa = lm_head_lin._gsa_buf.item()
lm_head_lin._activation_global_scale = fixed_gsa
lm_head_lin._use_runtime_gsa = False
@@ -1669,7 +1637,7 @@ def main():
print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
logits = lm_head_lin(x_out)
logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out)
if profile: torch.cuda.synchronize()
t_lm = time.perf_counter()
# Check thinking start token logit on first step