Compare commits
9 Commits
pure-nvfp4
...
v-precisio
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4fe73fe713 | | |
| | f577ed97f4 | | |
| | 1121cd7b47 | | |
| | f3bb0ca08c | | |
| | 470e65fb19 | | |
| | 2dd16d5789 | | |
| | 95e45a87e3 | | |
| | ef94c48957 | | |
| | 715602c87c | | |
94
GETTING+CUDAGRAPH_READY.md
Normal file
@@ -0,0 +1,94 @@

# DSV4 → vLLM: CUDA-Graph Safety / GPU-Native Requirements (PART 2 companion)

**Goal:** the per-step decode forward must be fully GPU-native so vLLM can capture and replay it. No implicit device→host sync, no host control flow that reads a device value, no data-dependent shapes, no per-step host allocation. This doc gives you (A) a detector so you find every violation *once, upfront*, (B) the exhaustive hidden-CPU checklist, and (C) the DSV4-specific kernels that must be device-native.

## The one rule that decides everything

Branching on a **host-known integer** (step number, position, batch size, dtype, static shape) is graph-compatible: you capture one graph per bucket and the scheduler picks by that integer. Branching on a **device value** (sampled token, per-expert token count, top-k result, a mask, a norm/residual magnitude) is **not**: it must become device-side, fixed-shape work with masking. Every violation below is a place where the host reads a device value.

You do **not** need one monolithic graph. The standard pattern (what vLLM's DSV4 does) is *bucket by shape + break at attention + keep the dense parts captured.* Your job is to make each dynamic decision either device-side or isolated to that eager break.

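As a rough illustration of the "bucket by shape" half of that pattern, the sketch below picks a pre-captured graph by a host-known batch size. The bucket sizes and the name `pick_bucket` are hypothetical and are not vLLM's API; the only point is that the selection reads a Python integer and never a tensor.

```python
from bisect import bisect_left

# Hypothetical capture buckets: one captured graph per padded batch size.
CAPTURE_BUCKETS = [1, 2, 4, 8, 16, 32]


def pick_bucket(batch_size, buckets=CAPTURE_BUCKETS):
    """Return the smallest captured bucket >= batch_size, or None for eager fallback."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    i = bisect_left(buckets, batch_size)  # buckets must be sorted ascending
    return buckets[i] if i < len(buckets) else None
```

A batch of 3 would be padded up to the 4-wide graph, while a batch larger than every bucket returns `None` and would run eagerly in this sketch.
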
---
## SECTION A — The detector (build this FIRST, before porting anything)

Stop hunting syncs by hand. Make them fail at the exact line:

```python
import torch

torch.cuda.set_sync_debug_mode("error")  # raises at any implicit device→host sync
# ... run one decode step of the forward ...
torch.cuda.set_sync_debug_mode("default")
```

And a capture-under-test (most illegal host ops error *during* capture):

```python
import torch

# decode_step, static_inputs, next_inputs: your forward and its pre-allocated buffers
g = torch.cuda.CUDAGraph()
# static input buffers allocated ONCE, outside capture:
with torch.cuda.graph(g):
    out = decode_step(static_inputs)  # capture fails loudly on .item(), sync, alloc, etc.
for _ in range(3):
    static_inputs.copy_(next_inputs)  # refill the static buffer in place
    g.replay()                        # replay must reproduce eager output
```

**Do this on the current `single_shot` forward first**: it inventories *every* existing sync in one pass, so you get the whole hunt-list upfront instead of discovering them one at a time during vLLM bring-up. Then gate every commit on both checks in CI; the day someone adds a `.item()`, the build fails at that line.

Also useful: `nsys` to eyeball CPU↔GPU stall gaps. Note that `compute-sanitizer --tool synccheck` checks in-kernel barrier usage (`__syncthreads()` and friends); it does not flag device→host syncs, so it complements the detector rather than replacing it.

---

## SECTION B — The hidden-CPU checklist (grep the hot path for these)

**Explicit device→host transfers**
`.item()` · `.cpu()` · `.tolist()` · `.numpy()` · `int(t)` / `float(t)` / `bool(t)` · `print(t)` · f-strings/logging that interpolate a tensor · `assert (device_condition)` (e.g. `assert (x>0).all()`) · `.to("cpu")`

**Host control flow on device values**
`if t:` · `if mask.any():` · `if x.sum() > thr:` · `while t > 0:` · `for i in range(n.item())` · convergence early-exit reading a device residual · choosing a kernel based on the sampled token

**Data-dependent shapes (these both change shape AND sync)**
`torch.nonzero` · `torch.where(cond)` (one-arg form) · `torch.unique` · `torch.bincount` (when it drives a shape) · boolean/mask indexing `x[mask]`, `x[x>0]` · `masked_select` · `reshape(n.item(), ...)` · any gather sized by a device-computed count

**Per-step host allocation**
`torch.empty/zeros/tensor([...])` created fresh inside the captured region · building a Python list then `torch.tensor(list, device=...)` · `np.*` anywhere on the path · any CPU tensor then `.to(device)` per step

**Host RNG**
`random.*` / `np.random.*` / Python rng for sampling → use a device generator / captured philox state

**Sync primitives & checks**
`torch.cuda.synchronize()` · `stream.synchronize()` · `torch.isnan(x).any()` / `isinf(...).any()` debug guards · pinned-copy syncs

**Sneaky ones (the "didn't realize" category)**
`min(t)` / `max(t)` (Python builtins compare elements on the host → sync; use `t.min()` / `t.max()`) · `sum(t)` (Python builtin iterates and launches one add per element; use `t.sum()`) · a `.cpu()`/`.item()` hidden inside a logging, assert, or metrics helper · `einops` rearrange with a data-dependent dim · telemetry/progress hooks that read tensors · indexing a tensor with a Python int derived from `.item()`

**What is FINE (no sync, don't waste time on these)**
`.shape` / `.size()` / `.numel()` / `.dtype` (host metadata, no sync) · branching on host-known ints (step/batch/static shape) · dtype/shape kernel dispatch · the **stop-token check, detokenize, and your BF16 precision-floor dequant** (all load-time or *outside* the captured graph; leave them on host, that's correct).

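One way to run that grep mechanically is a small regex scan over the hot-path source. This is a minimal sketch, not a complete linter: the pattern list covers only a handful of the entries above, the comment stripping is naive (it would misread a `#` inside a string literal), and `find_sync_suspects` is a name made up for this example.

```python
import re

# A few of the checklist entries, written as regexes. Extend as needed.
SYNC_PATTERNS = [
    r"\.item\(\)",
    r"\.cpu\(\)",
    r"\.tolist\(\)",
    r"\.numpy\(\)",
    r"\.to\(\s*['\"]cpu['\"]\s*\)",
    r"torch\.nonzero\(",
    r"torch\.unique\(",
    r"masked_select\(",
    r"torch\.cuda\.synchronize\(",
]
_SYNC_RE = re.compile("|".join(SYNC_PATTERNS))


def find_sync_suspects(source):
    """Return (line_number, matched_text) for each line with a suspicious call."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        code = line.split("#", 1)[0]  # drop trailing comments (naive)
        match = _SYNC_RE.search(code)
        if match:
            hits.append((lineno, match.group(0)))
    return hits
```

A scan like this only narrows the search. The sync debug mode and the capture test from Section A remain the authoritative checks, since a regex cannot see a `.item()` hidden behind a helper call.
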
---
## SECTION C — DSV4-specific kernels that must be GPU-native

| # | Hazard (current host/dynamic behavior) | Requirement | vLLM reference |
|---|---|---|---|
| 1 | Compressor returns `None` for 3/4 (CSA) or 127/128 (HCA) decode steps (a periodic host branch) | Compress **every** step into a persistent partial-state/ring buffer; emit the compressed entry **device-side** on the boundary | `save_partial_states`, `fused_compress_quant_cache` |
| 2 | KV grows each step → attention shape changes | Paged KV (fixed blocks + block table) captured at fixed max-len with masking, **or** make attention the eager break | `breakable_cudagraph` / `eager_break_during_capture`; `AttentionCGSupport.ALWAYS` |
| 3 | Indexer top-k → host reads selected count to size gather | Always gather fixed `k` (padded), mask invalid; no host read of the count | `dequant_gather_k_cutedsl` (fixed-shape gather) |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | Routing permutation/offsets computed **on device**; grouped GEMM with device offsets and a fixed total launch | `prepare_megamoe` |
| 5 | Next token / positions managed on host, fresh tensors per step | Static I/O buffers allocated once; **in-place** `copy_` of next token; positions via device-side increment (or per-shape bucketed graphs) | vLLM persistent input buffers |

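Row 1 is the least obvious of the five, so here is a toy model of the "compress every step" requirement. It is plain Python standing in for device code, and it pools with a simple mean rather than whatever the real compressor computes; `PartialStateCompressor` and its fields are illustrative names only. What it shows is the shape of the fix: every step does identical work and returns an `(entry, valid)` pair, so nothing ever returns `None`.

```python
class PartialStateCompressor:
    """Toy stand-in: averages `ratio` consecutive values into one entry."""

    def __init__(self, ratio):
        self.ratio = ratio
        self.acc = 0.0   # persistent partial state (a device buffer in a real port)
        self.count = 0   # steps accumulated in the current block

    def step(self, value):
        """Same work every step; returns (entry, valid) instead of None."""
        self.acc += value
        self.count += 1
        valid = int(self.count == self.ratio)  # 0/1 mask, a device-side compare in practice
        entry = self.acc / self.ratio          # always computed; meaningful only when valid
        # Branch-free reset: multiply by 0 on the block boundary, by 1 otherwise.
        self.acc *= (1 - valid)
        self.count *= (1 - valid)
        return entry, bool(valid)
```

With `ratio=4` the `valid` flag is set on every fourth call, matching the 3-in-4 pattern in the table, and the consumer masks out entries whose flag is clear instead of branching on `None`.
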
Also confirm:
- **Sinkhorn** runs a **fixed 20 iterations with no host convergence check** (a `while not converged` loop reading a device residual breaks capture). Fixed-iteration = safe.
- **Sampler** is device-side; `repetition_penalty` reads from a **fixed-size device** recent-token buffer (not a growing Python list); the EOS/stop decision is a host step **outside** the graph (correct).

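The fixed-iteration point can be sketched in a few lines. This assumes a textbook row/column normalization on a small positive matrix in plain Python, which may differ from the model's actual Sinkhorn variant; `sinkhorn_fixed` is a hypothetical name. The loop count is a host-known integer, so the control flow never depends on the data.

```python
def sinkhorn_fixed(matrix, n_iters=20):
    """Alternate row and column normalization for a fixed number of iterations."""
    m = [row[:] for row in matrix]  # work on a copy, leave the input untouched
    n_cols = len(m[0])
    for _ in range(n_iters):        # no residual check, no early exit
        for row in m:               # normalize each row to sum to 1
            s = sum(row)
            for j in range(n_cols):
                row[j] /= s
        for j in range(n_cols):     # normalize each column to sum to 1
            s = sum(row[j] for row in m)
            for row in m:
                row[j] /= s
    return m
```

A convergence-checked version would compare a residual against a tolerance each iteration, which on a GPU means reading a device value on the host; running the full fixed count trades a few redundant iterations for a capturable loop.
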
---
## SECTION D — Integration order

1. **Build Section A's detector and run it on the current forward** to get the full sync inventory in one pass.
2. Fix Section C's five device-native kernels (these are the structural ones; the rest of Section B tends to be incidental `.item()`s once these are right).
3. Re-run capture-under-test until it captures clean and replay matches eager bit-for-bit.
4. Gate every commit on the capture test so violations can never silently return.

## Guardrails

- Keep the stop-check, detokenize, and load-time BF16 dequant on the host. They're outside the captured region by design; don't contort them to be "graph-safe."
- Decide the attention model up front (paged-capturable vs eager-break). Retrofitting it later forces a KV-cache rewrite.
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.

69
WALKING_BACK_SOME_QUANTS.md
Normal file
@@ -0,0 +1,69 @@
# DSV4 Precision Floor — PyTorch Validation (PART 1) + Native Port (PART 2)

**What we learned:** the NVFP4 precision floor for this model is to keep the **LM head**, the **router gate**, and the **compressor/indexer helper projections** in BF16, with **one exception**: the **CSA indexer QK path stays FP4** (it was explicitly FP4-QATed; the other compressor projections were not, so PTQ-ing them to FP4 breaks them). We validated each change individually. Now do all of them together, simple PyTorch first, then native.

---

## ⚠️ First: the CUDA illegal-memory-access (you're calling the wrong dequant)

There are **two** functions with nearly the same name:

- `single_shot_inference.py:238`: `dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)`, **pure PyTorch** (does `weight_scale.repeat_interleave(16,1) * scales`). This is what `nvfp4_linear_ref` uses, so it is your **validated reference**. It cannot cause an illegal access.
- `dsv4/ops/quantize.py:377`: `dequantize_nvfp4(x_fp4, x_sf, gsa)`, which calls the **CUDA kernel** `dequant_nvfp4.cu`. **This is the one crashing.**

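For orientation, the arithmetic an NVFP4 weight dequant performs can be sketched in plain Python. This is not the code in `single_shot_inference.py`: it assumes the low nibble of each byte holds the first element, takes the block scales as already-decoded floats, and ignores `input_scale`. The magnitude table is the standard one for 4-bit E2M1 floats; `decode_e2m1` and `dequant_row` are names invented for this example.

```python
# Magnitudes representable by a 4-bit E2M1 float (sign bit handled separately).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode one 4-bit code: bit 3 is the sign, bits 0-2 index the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def dequant_row(packed, block_scales, global_scale, block=16):
    """Dequantize one row: two values per byte, one scale per `block` values."""
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF))  # low nibble first (assumption)
        values.append(decode_e2m1(byte >> 4))
    # Broadcast each block scale over its block, like repeat_interleave(16, 1).
    return [v * block_scales[i // block] * global_scale
            for i, v in enumerate(values)]
```

Because every index here is derived from the lengths of the inputs themselves, a scale tensor of the wrong size raises an `IndexError` instead of reading out of bounds, which is the property that makes the pure-PyTorch path safe for weights.
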
The precision-floor code (lines 328 / 333 / 426: kv_proj, gate_proj, wp) imports the **CUDA** one and feeds it **weights**. But that kernel was written for the **activation / KV-gather** path; read its own docstring: *"compressed KV is stored as NVFP4, dequantized on-the-fly."* It assumes row-major `(M, N/16)` block scales, per-row `gsa`, `N=512`.

The host wrapper only does `TORCH_CHECK(sf_data.size(0) == M)`, so it validates the scale's **row count and nothing else** (not width, not total size, not contiguity). The kernel then indexes `sf_data[m*(N/16) + n_block]` flat. For a weight whose scale isn't *exactly* contiguous row-major `(M, N/16)` (different width, padding, a non-contiguous `.to(dev)` view, or the GEMM swizzle), that index walks off the allocation → **async illegal access, surfacing at the next sync (the compressor load).** The activation/KV path never tripped it because those scales already match the assumed layout.

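The out-of-bounds arithmetic is easy to check by hand. The sketch below assumes the flat indexing quoted above and uses illustrative shapes (a 1024 × 7168 weight, and a scale buffer sized as if `N` were 512); the function names are hypothetical.

```python
def max_sf_index(M, N, block=16):
    """Largest flat index read by sf_data[m * (N / block) + n_block]."""
    blocks_per_row = N // block
    return (M - 1) * blocks_per_row + (blocks_per_row - 1)


def sf_is_large_enough(sf_numel, M, N, block=16):
    """True when a flat scale buffer covers every index the kernel will read."""
    return sf_numel >= M * (N // block)


# Illustrative shapes only: M rows, N = 7168 input features.
M, N = 1024, 7168
needed = M * (N // 16)          # 1024 * 448 scale entries
sized_for_512 = M * (512 // 16) # what a buffer laid out for N=512 would hold
```

In this example the kernel would read up to index `max_sf_index(M, N)`, far past the end of a buffer holding `sized_for_512` entries, while a check on `size(0) == M` alone would still pass.
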
**Confirm it in 2 minutes** (the error is async, so do this to localize it):

```bash
compute-sanitizer --tool memcheck <your harness> ...   # will name dequant_nvfp4_kernel + the sf_data read
# or: CUDA_LAUNCH_BLOCKING=1 to move the report to the offending launch
```

And add these guards to `dequant_nvfp4_cuda` in `dequant_nvfp4.cu`. They turn the async crash into an immediate, located error and print the size mismatch:

```cpp
TORCH_CHECK(fp4_data.is_contiguous() && sf_data.is_contiguous(), "dequant inputs must be contiguous");
TORCH_CHECK(sf_data.numel() >= (int64_t)M * (N/16), "sf too small: have ", sf_data.numel(), " need ", (int64_t)M*(N/16));
TORCH_CHECK(fp4_data.numel() >= (int64_t)M * (N/2), "fp4 too small: have ", fp4_data.numel(), " need ", (int64_t)M*(N/2));
```

You don't need the CUDA kernel here at all (see PART 1): these weights are dequanted **once at load**, so there's zero performance reason to use a custom kernel for them.

---

## PART 1 — PyTorch quick version (all floor fixes together, simple, no crash)

Goal: one combined config, pure PyTorch, prove correctness end-to-end. This also sidesteps the OOB by not using the CUDA dequant for weights.

1. **Swap the three weight-dequant call sites (328/333/426) to the PyTorch reference.** The CUDA `dequantize_nvfp4(kv_w, kv_ws, gsa)` becomes the PyTorch `dequant_nvfp4(kv_w, kv_ws, kv_ws2, kv_isc)`, and you can delete the manual `gsa = torch.tensor([ws2_v]*shape[0])` lines, because the PyTorch version handles `weight_scale_2` / `input_scale` internally. Be explicit about *which* function you import (they're nearly identically named, which is how this got crossed). Example:

   ```python
   from single_shot_inference import dequant_nvfp4 as dequant_nvfp4_torch  # the pure-PyTorch one

   # kv_proj:
   self._kv_bf16 = dequant_nvfp4_torch(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
   # gate_proj, wp: same pattern
   ```

2. **LM head → BF16, router gate → BF16.** Dequant their FP4 weights to BF16 once at load via the same PyTorch path, then run them as plain `F.linear`. (The gate is tiny; the LM head is the only sizable one and it's ~1.4 GB, negligible against the KV/concurrency budget.)
3. **Keep the CSA indexer QK path in FP4. Do NOT dequant it.** Only the QK projection of the indexer was QATed. Its non-QATed siblings in the compressor go to BF16 with everything else.
4. **Run a clean generation** with the fixed chat template (the official `encoding/encoding_dsv4.py`, not the hand-rolled path). Confirm: coherent, **no repetition loop**, **clean stop**, Paris top-1 on the canonical probe, and run **≥ a few hundred tokens** so HCA actually engages (HCA's first compressed entry only forms at 128 tokens).
5. **A/B insurance:** this is the all-at-once config. If it regresses versus the individual fixes, flip one component FP4↔BF16 at a time to find the interaction, and record which ones were necessary (that table is the NVIDIA-writeup evidence).

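The one-at-a-time flip in step 5 can be enumerated mechanically. The sketch below is only a bookkeeping aid: the component labels in `FLOOR_CONFIG` are shorthand invented here for the four groups named above, and wiring each variant into the loader is left to the harness.

```python
# Combined precision-floor config; keys are illustrative labels, not real flags.
FLOOR_CONFIG = {
    "lm_head": "bf16",
    "router_gate": "bf16",
    "compressor_proj": "bf16",
    "indexer_qk": "fp4",
}

_FLIP = {"bf16": "fp4", "fp4": "bf16"}


def one_flip_variants(base):
    """Return (flipped_component, config) pairs that differ from base in one entry."""
    variants = []
    for name in base:
        cfg = dict(base)              # copy so the baseline stays untouched
        cfg[name] = _FLIP[base[name]]
        variants.append((name, cfg))
    return variants
```

Running the same generation check against each variant and noting which flips change the outcome gives the necessity table directly, with one row per component.
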
---
## PART 2 — Native CuteDSL / CUDA version

Only after PART 1 validates the combined config (it becomes your reference for it).

1. **Fix the weight dequant path** (you have two options; pick one):
   - *Simplest:* keep dequanting these few weights to BF16 **at load in PyTorch** (PART 1) even in the native build. It's a one-time load op with no hot-path cost, so there's no need to native-ize it at all.
   - *If you insist on the CUDA kernel for load:* add the `numel`/contiguity guards above, then make the scale match what the kernel reads. The raw checkpoint `weight_scale` appears to be row-major **before** `finalize_weights` (the production GEMM swizzles at finalize, see the "K-major + swizzle" step ~line 1352, so the *raw* scale is unswizzled). The guards will tell you if it's actually `(M, N/16)` contiguous; if not, make it contiguous before launch or teach the kernel the real stride. Also: the kernel was built around `N=512`, while for these weights `N=in` (≈7168), so make sure nothing downstream hardcodes 512.
2. **Hot-path natives are unchanged:** FP8 FMHA, FP4 MoE, and the **FP4 CSA indexer QK** all stay as they are. The floor change only touches load-time weight handling + two small GEMMs (gate, lm_head) that run as native **BF16** (cuBLAS/standard), not FP4.
3. **Re-validate per-layer cosine** of the native build against the PART 1 PyTorch combined-config reference before declaring done.

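The per-layer cosine check in step 3 amounts to the comparison below. It is written in plain Python over lists of floats for clarity; in practice the same formula would run on flattened layer outputs. The `0.999` threshold is an assumption chosen for illustration, not a value taken from this project, and `failing_layers` is a hypothetical helper name.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def failing_layers(native, reference, threshold=0.999):
    """Indices of layers whose native output drifts from the reference."""
    return [
        i for i, (n, r) in enumerate(zip(native, reference))
        if cosine(n, r) < threshold
    ]
```

Reporting the first failing index is usually the useful part, since an error introduced at one layer tends to propagate into every layer after it.
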
---
## Guardrails

- Don't reintroduce the **CUDA** `dequantize_nvfp4` for **weights** until the wrapper guards are in and the scale layout is confirmed. For now the PyTorch dequant is correct and crash-proof.
- The two functions `dequant_nvfp4` (PyTorch, weights) and `dequantize_nvfp4` (CUDA, activations/KV) are a foot-gun. Consider renaming the CUDA one to `dequantize_nvfp4_kvcache` so this can't recur.
- Only the **CSA indexer QK** path is FP4-QATed. Do not let FP4 creep onto its non-QATed siblings.
- Validate end-to-end (coherent + non-looping + clean stop + HCA-depth) **before** calling it done.

@@ -9,6 +9,7 @@ NO PyTorch SDPA fallback. NO dequant+matmul for production projections.
|
||||
This is the ground truth for vLLM / SGLang integration.
|
||||
"""
|
||||
import os, sys, time, json, math, argparse, logging
|
||||
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # Catch async CUDA errors immediately
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
@@ -302,6 +303,8 @@ class Compressor:
|
||||
self.is_csa = (ratio == 4); self.kv_dim = 2 * head_dim if self.is_csa else head_dim
|
||||
self.kv_lin = None # production Nvfp4Linear for kv_proj
|
||||
self.gate_lin = None # production Nvfp4Linear for gate_proj
|
||||
self._kv_bf16 = None # BF16 weight for kv_proj (dequantized from NVFP4)
|
||||
self._gate_bf16 = None # BF16 weight for gate_proj (dequantized from NVFP4)
|
||||
self.ape = None; self.kv_norm_w = None
|
||||
self._reduce_loaded = False
|
||||
# P7: Decode buffering — accumulate hidden_states until we have a complete block.
|
||||
@@ -312,26 +315,24 @@ class Compressor:
|
||||
self._buf_len = 0
|
||||
|
||||
def load(self, w, pfx, dev=None):
|
||||
"""Load weights and build production Nvfp4Linear instances."""
|
||||
"""Load weights and build BF16 projections (dequantized from NVFP4)."""
|
||||
if dev is None: dev = self.device
|
||||
# Build production NVFP4 GEMM instances for the two projections
|
||||
# kv_proj: in=7168, out=kv_dim (1024 for CSA, 512 for HCA)
|
||||
# gate_proj: same shapes
|
||||
# Compressor projections are NOT explicitly FP4-QATed — dequant to BF16, use F.linear
|
||||
# CRITICAL: Use the PyTorch dequant_nvfp4 (defined in this file), NOT the CUDA
|
||||
# dequantize_nvfp4 from dsv4/ops/quantize.py. The CUDA kernel assumes
|
||||
# activation/KV scale layout (row-major (M, N/16)) and crashes on weight scales
|
||||
# that don't match — async illegal memory access surfaces at next sync.
|
||||
kv_w, kv_ws, kv_ws2, kv_isc = get_nvfp4_weight(w, pfx, 'kv_proj')
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
|
||||
if kv_w is not None:
|
||||
kv_out = kv_w.shape[0] # N_packed
|
||||
kv_in = kv_w.shape[1] * 2 # K_packed * 2
|
||||
self.kv_lin = make_nvfp4_linear(kv_in, kv_out, dev, w, pfx, 'kv_proj')
|
||||
self._kv_bf16 = dequant_nvfp4(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
|
||||
if gate_w is not None:
|
||||
gate_out = gate_w.shape[0]
|
||||
gate_in = gate_w.shape[1] * 2
|
||||
self.gate_lin = make_nvfp4_linear(gate_in, gate_out, dev, w, pfx, 'gate_proj')
|
||||
self._gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc).to(dev).contiguous()
|
||||
self.ape = w.get(f"{pfx}.position_bias")
|
||||
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
|
||||
def forward(self, hidden_states, positions):
|
||||
if self.ratio == 0 or self.kv_lin is None: return None, None, None
|
||||
if self.ratio == 0 or self._kv_bf16 is None: return None, None, None
|
||||
T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device
|
||||
|
||||
# P7: Buffer decode steps until we have a complete block.
|
||||
@@ -358,9 +359,9 @@ class Compressor:
|
||||
n_complete = T // r
|
||||
if n_complete == 0: return None, None, None
|
||||
|
||||
# Step 1-2: NVFP4 GEMM projections → FP32 for compress
|
||||
kv = self.kv_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
gate = self.gate_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
# Step 1-2: BF16 F.linear projections → FP32 for compress
|
||||
kv = torch.nn.functional.linear(hidden_states, self._kv_bf16).float() # (T, kv_dim) FP32
|
||||
gate = torch.nn.functional.linear(hidden_states, self._gate_bf16).float() # (T, kv_dim) FP32
|
||||
|
||||
# Step 3: CUDA softmax/reduce kernel → FP32
|
||||
# KV-1/KV-2: Return FP32. Caller applies RoPE, then quantizes to NVFP4.
|
||||
@@ -398,22 +399,23 @@ class Indexer:
|
||||
"""
|
||||
def __init__(self, n_ih, ihd, top_k, device):
|
||||
self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device
|
||||
self.q_b_lin = None # production Nvfp4Linear for q_b_proj
|
||||
self.wp_lin = None # production Nvfp4Linear for weights_proj
|
||||
self.q_b_lin = None # production Nvfp4Linear for q_b_proj (FP4-QATed)
|
||||
self._wp_bf16 = None # BF16 weight for weights_proj (dequantized from NVFP4)
|
||||
self.compressor = None
|
||||
|
||||
def load(self, w, pfx, dev=None):
|
||||
if dev is None: dev = self.device
|
||||
qb_w, qb_ws, qb_ws2, qb_isc = get_nvfp4_weight(w, pfx, 'q_b_proj')
|
||||
wp_w, wp_ws, wp_ws2, wp_isc = get_nvfp4_weight(w, pfx, 'weights_proj')
|
||||
# q_b_proj IS the FP4-QATed QK path — keep as NVFP4
|
||||
if qb_w is not None:
|
||||
qb_out = qb_w.shape[0]
|
||||
qb_in = qb_w.shape[1] * 2
|
||||
self.q_b_lin = make_nvfp4_linear(qb_in, qb_out, dev, w, pfx, 'q_b_proj')
|
||||
# weights_proj is NOT FP4-QATed — dequant to BF16 via PyTorch reference
|
||||
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4 (see Compressor.load)
|
||||
if wp_w is not None:
|
||||
wp_out = wp_w.shape[0]
|
||||
wp_in = wp_w.shape[1] * 2
|
||||
self.wp_lin = make_nvfp4_linear(wp_in, wp_out, dev, w, pfx, 'weights_proj')
|
||||
self._wp_bf16 = dequant_nvfp4(wp_w.to(dev), wp_ws.to(dev), wp_ws2, wp_isc).to(dev).contiguous()
|
||||
# Indexer compressor weights are directly under the indexer prefix
|
||||
# (e.g. *.indexer.kv_proj.weight), NOT nested under *.indexer.compressor.
|
||||
if f"{pfx}.kv_proj.weight" in w:
|
||||
@@ -436,7 +438,7 @@ class Indexer:
|
||||
li = layer_idx
|
||||
|
||||
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
|
||||
w_h = self.wp_lin(hidden_states) # (T, n_ih)
|
||||
w_h = torch.nn.functional.linear(hidden_states, self._wp_bf16) # (T, n_ih) BF16
|
||||
|
||||
# B2: FP8 tensor-core scoring path.
|
||||
# Indexer keys are stored as FP8_E4M3 in the KV cache.
|
||||
@@ -1306,50 +1308,26 @@ def main():
|
||||
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
# NVFP4 production GEMM for router gate
|
||||
# Custom CuTeDSL fused kernel crashes MLIR optimizer,
|
||||
# so we use Nvfp4Linear (proven production path).
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
# BF16 router gate — dequantize NVFP4 to BF16, use F.linear
|
||||
E = cfg["n_routed_experts"]
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
# Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
|
||||
gate_lin.fp4 = [gate_w_view]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
|
||||
# Checkpoint has NVFP4 gate weight — dequantize to BF16
|
||||
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4
|
||||
# (same fix as Compressor.load — CUDA kernel crashes on weight scale layouts)
|
||||
gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc)
|
||||
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
|
||||
else:
|
||||
# BF16 gate weight: quantize to NVFP4
|
||||
# BF16 gate weight from checkpoint
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
|
||||
else:
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
gate_bf16 = gw.bfloat16().to(dev)
|
||||
if gate_bf16.shape[0] != H:
|
||||
gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E)
|
||||
router.W_gate = gate_bf16.contiguous()
|
||||
# No gate_lin — force BF16 dispatch path
|
||||
router.gate_lin = None
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True)
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
@@ -1397,21 +1375,11 @@ def main():
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
# lm_head: NVFP4 production GEMM
|
||||
# lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization)
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
|
||||
lm_head_lin.gs = [lm_gs]
|
||||
lm_head_lin.ws2 = [None]
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
lm_head_lin._use_runtime_gsa = True
|
||||
lm_head_lin.finalize_weights()
|
||||
lm_w = None
|
||||
print(" lm_head: NVFP4 production GEMM")
|
||||
lm_head_lin = None # Use raw BF16 F.linear for lm_head
|
||||
lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear
|
||||
print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)")
|
||||
final_norm_w = all_w.get("model.norm.weight")
|
||||
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
|
||||
|
||||
@@ -1660,8 +1628,8 @@ def main():
|
||||
gl._activation_global_scale = fixed_gsa
|
||||
gl._use_runtime_gsa = False
|
||||
n_fixed += 1
|
||||
# lm_head
|
||||
if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
# lm_head (BF16 — no gsa needed)
|
||||
if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
fixed_gsa = lm_head_lin._gsa_buf.item()
|
||||
lm_head_lin._activation_global_scale = fixed_gsa
|
||||
lm_head_lin._use_runtime_gsa = False
|
||||
@@ -1669,7 +1637,7 @@ def main():
|
||||
print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
logits = lm_head_lin(x_out)
|
||||
logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out)
|
||||
if profile: torch.cuda.synchronize()
|
||||
t_lm = time.perf_counter()
|
||||
# Check thinking start token logit on first step
|
||||
|
||||