Compare commits
1 Commits
v-precisio
...
pure-nvfp4
| Author | SHA1 | Date | |
|---|---|---|---|
| 3320abfe24 |
59
CORRECTNESS_FIX_ATTEMPTS.md
Normal file
59
CORRECTNESS_FIX_ATTEMPTS.md
Normal file
@@ -0,0 +1,59 @@
|
||||
## 1. Possible bug: compressor positional bias is being added to KV content
|
||||
|
||||
In your `dsv4/kernels/cuda/compressor_reduce.cu`, the compressor appears to do this in both CSA and HCA paths:
|
||||
|
||||
```cpp
|
||||
g += pb;
|
||||
kv_val += pb; // suspicious / wrong
|
||||
```
|
||||
|
||||
The official compressor equations add positional bias only to the **compression weights/logits** `Z + B`, then use those weights to sum the raw projected KV content `C`. The bias is not added to the KV value itself. The paper defines compression as softmax over `Z + B`, followed by a weighted sum of `C`.
|
||||
|
||||
So this should be:
|
||||
|
||||
```cpp
|
||||
g += pb;
|
||||
// do not add pb to kv_val
|
||||
```
|
||||
|
||||
That bug would poison every compressed KV entry with learned positional-bias content. It may not fully explain the first token for a tiny prompt if SWA dominates, but it is absolutely wrong relative to the official architecture and will degrade CSA/HCA context quality. If your unit tests passed, they may have been comparing against a reference that made the same mistake or were too short to expose it.
|
||||
|
||||
## 2. Don’t use `think_start` as the canary here
|
||||
|
||||
In official `thinking` mode, the prompt formatter typically appends the assistant marker plus `<think>` before generation. That means decode step 0 is already *inside* the thinking span. The model should not necessarily emit `think_start`; a low `think_start` logit is not itself evidence that the model “failed to enter thinking mode.”
|
||||
|
||||
For this particular prompt, a high `think_end` logit can even be plausible because “The capital of France is” does not need much reasoning. Run the same current kernel with official **chat mode**, greedy decoding, no repetition penalty, no top-k/top-p, and compare first-token logits. The `think_start = 1.77` observation is probably a misleading diagnostic.
|
||||
|
||||
## 3. Indexer parity still looks suspect
|
||||
|
||||
The official CSA/HCA details include RMSNorm on queries and compressed KV before attention, partial RoPE on the last 64 dims, sliding-window KV, and attention sink behavior.
|
||||
|
||||
For the CSA **indexer**, the official reference does more than a plain q/k dot product: indexer Q and compressed indexer K get the appropriate RoPE/rotation treatment, and the QK path is one of the explicitly FP4-QATed pieces. If your current indexer compressor is just producing compressed keys without the same rotate/RoPE/FP4 path, CSA top-k can select plausible-looking but wrong blocks. Again, probably not the first-token issue on a short prompt, but it will matter for any real context.
|
||||
|
||||
## 4. Check SwiGLU clamp ordering
|
||||
|
||||
The official behavior is effectively:
|
||||
|
||||
```python
|
||||
gate = clamp(gate, max=swiglu_limit)
|
||||
up = clamp(up, -swiglu_limit, swiglu_limit)
|
||||
out = silu(gate) * up
|
||||
```
|
||||
|
||||
If your fused path clamps `silu(gate)` instead of clamping raw `gate` before SiLU, it is not equivalent. This is especially worth checking in both routed MoE and shared expert fused kernels, because a small-looking activation semantic mismatch repeats through every layer.
|
||||
|
||||
|
||||
## 5. DEQUANT TO BF16 IN THIS ORDER JUST TO SEE WHAT HAPPENS (You are allowed to break the no bf16 rule for this because we can always revert back to previous commit)
|
||||
|
||||
|
||||
The most suspicious surfaces to temporarily dequantize are, in order:
|
||||
|
||||
1. **lm head** — FP4 lm head can directly flatten or reorder vocabulary logits.
|
||||
2. **router gate** — slight errors can change top-6 experts; wrong expert IDs are much worse than a small GEMM error.
|
||||
3. **shared expert** — official routed experts are the FP4 target; shared expert is a different sensitivity profile.
|
||||
4. **attention q/kv/o projections and grouped output projection** — these are not described as full FP4 QAT targets.
|
||||
5. **compressor/indexer helper projections** — only the CSA indexer QK path is explicitly FP4-QATed, not the whole compressor.
|
||||
|
||||
If a BF16/FP8 lm head alone makes `Paris` / `.` / answer-like tokens dominate again, you’ve found a high-leverage culprit. My money is on LM Head needing to be BF16
|
||||
|
||||
The fastest triage is basically: run `thinking_mode=chat`, greedy; switch only `lm_head` back to BF16/FP8; then switch router back; then patch the compressor bias-to-KV bug. If any one of those sharply separates the first-token distribution, you’ll know where to spend kernel time.
|
||||
@@ -1,94 +0,0 @@
|
||||
# DSV4 → vLLM: CUDA-Graph Safety / GPU-Native Requirements (PART 2 companion)
|
||||
|
||||
**Goal:** the per-step decode forward must be fully GPU-native so vLLM can capture and replay it. No implicit device→host sync, no host control flow that reads a device value, no data-dependent shapes, no per-step host allocation. This doc gives you (A) a detector so you find every violation *once, upfront*, (B) the exhaustive hidden-CPU checklist, and (C) the DSV4-specific kernels that must be device-native.
|
||||
|
||||
## The one rule that decides everything
|
||||
|
||||
Branching on a **host-known integer** (step number, position, batch size, dtype, static shape) is graph-compatible — you capture one graph per bucket and the scheduler picks by that integer. Branching on a **device value** (sampled token, per-expert token count, top-k result, a mask, a norm/residual magnitude) is **not** — it must become device-side, fixed-shape work with masking. Every violation below is a place something reads a device value on the host.
|
||||
|
||||
You do **not** need one monolithic graph. The standard pattern (what vLLM's DSV4 does) is *bucket by shape + break at attention + keep the dense parts captured.* Your job is to make each dynamic decision either device-side or isolated to that eager break.
|
||||
|
||||
---
|
||||
|
||||
## SECTION A — The detector (build this FIRST, before porting anything)
|
||||
|
||||
Stop hunting syncs by hand. Make them fail at the exact line:
|
||||
|
||||
```python
|
||||
import torch
|
||||
torch.cuda.set_sync_debug_mode("error") # raises at any implicit device→host sync
|
||||
# ... run one decode step of the forward ...
|
||||
torch.cuda.set_sync_debug_mode("default")
|
||||
```
|
||||
|
||||
And a capture-under-test (most illegal host ops error *during* capture):
|
||||
```python
|
||||
g = torch.cuda.CUDAGraph()
|
||||
# static input buffers allocated ONCE, outside capture:
|
||||
with torch.cuda.graph(g):
|
||||
out = decode_step(static_inputs) # capture fails loudly on .item(), sync, alloc, etc.
|
||||
for _ in range(3):
|
||||
static_inputs.copy_(next_inputs); g.replay() # replay must reproduce eager output
|
||||
```
|
||||
|
||||
**Do this on the current `single_shot` forward first** — it inventories *every* existing sync in one pass, so you get the whole hunt-list upfront instead of discovering them one at a time during vLLM bring-up. Then gate every commit on both checks in CI; the day someone adds a `.item()`, the build fails at that line.
|
||||
|
||||
Also useful: `compute-sanitizer --tool synccheck`, and `nsys` to eyeball CPU↔GPU stall gaps.
|
||||
|
||||
---
|
||||
|
||||
## SECTION B — The hidden-CPU checklist (grep the hot path for these)
|
||||
|
||||
**Explicit device→host transfers**
|
||||
`.item()` · `.cpu()` · `.tolist()` · `.numpy()` · `int(t)` / `float(t)` / `bool(t)` · `print(t)` · f-strings/logging that interpolate a tensor · `assert (device_condition)` (e.g. `assert (x>0).all()`) · `.to("cpu")`
|
||||
|
||||
**Host control flow on device values**
|
||||
`if t:` · `if mask.any():` · `if x.sum() > thr:` · `while t > 0:` · `for i in range(n.item())` · convergence early-exit reading a device residual · choosing a kernel based on the sampled token
|
||||
|
||||
**Data-dependent shapes (these both change shape AND sync)**
|
||||
`torch.nonzero` · `torch.where(cond)` (one-arg form) · `torch.unique` · `torch.bincount` (when it drives a shape) · boolean/mask indexing `x[mask]`, `x[x>0]` · `masked_select` · `reshape(n.item(), ...)` · any gather sized by a device-computed count
|
||||
|
||||
**Per-step host allocation**
|
||||
`torch.empty/zeros/tensor([...])` created fresh inside the captured region · building a Python list then `torch.tensor(list, device=...)` · `np.*` anywhere on the path · any CPU tensor then `.to(device)` per step
|
||||
|
||||
**Host RNG**
|
||||
`random.*` / `np.random.*` / Python rng for sampling → use a device generator / captured philox state
|
||||
|
||||
**Sync primitives & checks**
|
||||
`torch.cuda.synchronize()` · `stream.synchronize()` · `torch.isnan(x).any()` / `isinf(...).any()` debug guards · pinned-copy syncs
|
||||
|
||||
**Sneaky ones (the "didn't realize" category)**
|
||||
`sum(t)` / `min(t)` / `max(t)` (Python builtins iterate → sync; use `t.sum()`) · a `.cpu()`/`.item()` hidden inside a logging, assert, or metrics helper · `einops` rearrange with a data-dependent dim · telemetry/progress hooks that read tensors · indexing a tensor with a Python int derived from `.item()`
|
||||
|
||||
**What is FINE (no sync, don't waste time on these)**
|
||||
`.shape` / `.size()` / `.numel()` / `.dtype` (host metadata, no sync) · branching on host-known ints (step/batch/static shape) · dtype/shape kernel dispatch · the **stop-token check, detokenize, and your BF16 precision-floor dequant** (all load-time or *outside* the captured graph — leave them on host, that's correct).
|
||||
|
||||
---
|
||||
|
||||
## SECTION C — DSV4-specific kernels that must be GPU-native
|
||||
|
||||
| # | Hazard (current host/dynamic behavior) | Requirement | vLLM reference |
|
||||
|---|---|---|---|
|
||||
| 1 | Compressor returns `None` for 3/4 (CSA) or 127/128 (HCA) decode steps — periodic host branch | Compress **every** step into a persistent partial-state/ring buffer; emit the compressed entry **device-side** on the boundary | `save_partial_states`, `fused_compress_quant_cache` |
|
||||
| 2 | KV grows each step → attention shape changes | Paged KV (fixed blocks + block table) captured at fixed max-len with masking, **or** make attention the eager break | `breakable_cudagraph` / `eager_break_during_capture`; `AttentionCGSupport.ALWAYS` |
|
||||
| 3 | Indexer top-k → host reads selected count to size gather | Always gather fixed `k` (padded), mask invalid; no host read of the count | `dequant_gather_k_cutedsl` (fixed-shape gather) |
|
||||
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | Routing permutation/offsets computed **on device**; grouped GEMM with device offsets and a fixed total launch | `prepare_megamoe` |
|
||||
| 5 | Next token / positions managed on host, fresh tensors per step | Static I/O buffers allocated once; **in-place** `copy_` of next token; positions via device-side increment (or per-shape bucketed graphs) | vLLM persistent input buffers |
|
||||
|
||||
Also confirm:
|
||||
- **Sinkhorn** runs a **fixed 20 iterations with no host convergence check** (a `while not converged` reading a device residual breaks capture). Fixed-iteration = safe.
|
||||
- **Sampler** is device-side; `repetition_penalty` reads from a **fixed-size device** recent-token buffer (not a growing Python list); the EOS/stop decision is a host step **outside** the graph (correct).
|
||||
|
||||
---
|
||||
|
||||
## SECTION D — Integration order
|
||||
|
||||
1. **Build Section A's detector and run it on the current forward** — get the full sync inventory in one pass.
|
||||
2. Fix Section C's five device-native kernels (these are the structural ones; the rest of Section B tends to be incidental `.item()`s once these are right).
|
||||
3. Re-run capture-under-test until it captures clean and replay matches eager bit-for-bit.
|
||||
4. Gate every commit on the capture test so violations can never silently return.
|
||||
|
||||
## Guardrails
|
||||
- Keep the stop-check, detokenize, and load-time BF16 dequant on the host — they're outside the captured region by design; don't contort them to be "graph-safe."
|
||||
- Decide the attention model up front (paged-capturable vs eager-break) — retrofitting it later forces a KV-cache rewrite.
|
||||
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.
|
||||
@@ -1,69 +0,0 @@
|
||||
# DSV4 Precision Floor — PyTorch Validation (PART 1) + Native Port (PART 2)
|
||||
|
||||
**What we learned:** the NVFP4 precision floor for this model is — keep **LM head** BF16, **router gate** BF16, and the **compressor/indexer helper projections** BF16, with the **one exception** that the **CSA indexer QK path stays FP4** (it was explicitly FP4-QATed; the other compressor projections were not, so PTQ-ing them to FP4 breaks). We validated each individually. Now do all of them together, simple-PyTorch first, then native.
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ First: the CUDA illegal-memory-access (you're calling the wrong dequant)
|
||||
|
||||
There are **two** functions with nearly the same name:
|
||||
|
||||
- `single_shot_inference.py:238` — `dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)` — **pure PyTorch** (does `weight_scale.repeat_interleave(16,1) * scales`). This is what `nvfp4_linear_ref` uses — your **validated reference**. It cannot cause an illegal access.
|
||||
- `dsv4/ops/quantize.py:377` — `dequantize_nvfp4(x_fp4, x_sf, gsa)` — calls the **CUDA kernel** `dequant_nvfp4.cu`. **This is the one crashing.**
|
||||
|
||||
The precision-floor code (lines 328 / 333 / 426: kv_proj, gate_proj, wp) imports the **CUDA** one and feeds it **weights**. But that kernel was written for the **activation / KV-gather** path — read its own docstring: *"compressed KV is stored as NVFP4, dequantized on-the-fly."* It assumes row-major `(M, N/16)` block scales, per-row `gsa`, `N=512`.
|
||||
|
||||
The host wrapper only does `TORCH_CHECK(sf_data.size(0) == M)` — it validates the scale's **row count and nothing else** (not width, not total size, not contiguity). The kernel then indexes `sf_data[m*(N/16) + n_block]` flat. For a weight whose scale isn't *exactly* contiguous row-major `(M, N/16)` — different width, padding, non-contiguous `.to(dev)` view, or the GEMM swizzle — that index walks off the allocation → **async illegal access, surfacing at the next sync (the compressor load).** The activation/KV path never tripped it because those scales already match the assumed layout.
|
||||
|
||||
**Confirm it in 2 minutes** (the error is async, so do this to localize it):
|
||||
```bash
|
||||
compute-sanitizer --tool memcheck <your harness> ... # will name dequant_nvfp4_kernel + the sf_data read
|
||||
# or: CUDA_LAUNCH_BLOCKING=1 to move the report to the offending launch
|
||||
```
|
||||
And add these guards to `dequant_nvfp4_cuda` in `dequant_nvfp4.cu` — they turn the async crash into an immediate, located error and print the size mismatch:
|
||||
```cpp
|
||||
TORCH_CHECK(fp4_data.is_contiguous() && sf_data.is_contiguous(), "dequant inputs must be contiguous");
|
||||
TORCH_CHECK(sf_data.numel() >= (int64_t)M * (N/16), "sf too small: have ", sf_data.numel(), " need ", (int64_t)M*(N/16));
|
||||
TORCH_CHECK(fp4_data.numel() >= (int64_t)M * (N/2), "fp4 too small: have ", fp4_data.numel(), " need ", (int64_t)M*(N/2));
|
||||
```
|
||||
|
||||
You don't need the CUDA kernel here at all (see PART 1) — these weights are dequanted **once at load**, so there's zero performance reason to use a custom kernel for them.
|
||||
|
||||
---
|
||||
|
||||
## PART 1 — PyTorch quick version (all floor fixes together, simple, no crash)
|
||||
|
||||
Goal: one combined config, pure PyTorch, prove correctness end-to-end. This also sidesteps the OOB by not using the CUDA dequant for weights.
|
||||
|
||||
1. **Swap the three weight-dequant call sites (328/333/426) to the PyTorch reference.** The CUDA `dequantize_nvfp4(kv_w, kv_ws, gsa)` becomes the PyTorch `dequant_nvfp4(kv_w, kv_ws, kv_ws2, kv_isc)` — and you can delete the manual `gsa = torch.tensor([ws2_v]*shape[0])` lines, because the PyTorch version handles `weight_scale_2` / `input_scale` internally. Be explicit about *which* function you import (they're nearly identically named — that's how this got crossed). Example:
|
||||
```python
|
||||
from single_shot_inference import dequant_nvfp4 as dequant_nvfp4_torch # the pure-PyTorch one
|
||||
# kv_proj:
|
||||
self._kv_bf16 = dequant_nvfp4_torch(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
|
||||
# gate_proj, wp: same pattern
|
||||
```
|
||||
2. **LM head → BF16, router gate → BF16.** Dequant their FP4 weights to BF16 once at load via the same PyTorch path, then run them as plain `F.linear`. (The gate is tiny; the LM head is the only sizable one and it's ~1.4 GB — negligible against the KV/concurrency budget.)
|
||||
3. **Keep the CSA indexer QK path in FP4 — do NOT dequant it.** Only the QK projection of the indexer was QATed. Its non-QATed siblings in the compressor go to BF16 with everything else.
|
||||
4. **Run a clean generation** with the fixed chat template (the official `encoding/encoding_dsv4.py`, not the hand-rolled path). Confirm: coherent, **no repetition loop**, **clean stop**, Paris top-1 on the canonical probe, and run **≥ a few hundred tokens** so HCA actually engages (HCA's first compressed entry only forms at 128 tokens).
|
||||
5. **A/B insurance:** this is the all-at-once config. If it regresses versus the individual fixes, flip one component FP4↔BF16 at a time to find the interaction — and record which ones were necessary (that table is the NVIDIA-writeup evidence).
|
||||
|
||||
---
|
||||
|
||||
## PART 2 — Native CuteDSL / CUDA version
|
||||
|
||||
Only after PART 1 validates the combined config (it becomes your reference for it).
|
||||
|
||||
1. **Fix the weight dequant path** (you have two options; pick one):
|
||||
- *Simplest:* keep dequanting these few weights to BF16 **at load in PyTorch** (PART 1) even in the native build. It's a one-time load op — no hot-path cost — so there's no need to native-ize it at all.
|
||||
- *If you insist on the CUDA kernel for load:* add the `numel`/contiguity guards above, then make the scale match what the kernel reads. The raw checkpoint `weight_scale` appears row-major **before** `finalize_weights` (the production GEMM swizzles at finalize — see the "K-major + swizzle" step ~line 1352 — so the *raw* scale is unswizzled). The guards will tell you if it's actually `(M, N/16)` contiguous; if not, make it contiguous before launch or teach the kernel the real stride. Also: the kernel was built around `N=512`; for weights `N=in` (≈7168) — make sure nothing downstream hardcodes 512.
|
||||
2. **Hot-path natives are unchanged:** FP8 FMHA, FP4 MoE, and the **FP4 CSA indexer QK** all stay as they are. The floor change only touches load-time weight handling + two small GEMMs (gate, lm_head) that run as native **BF16** (cuBLAS/standard), not FP4.
|
||||
3. **Re-validate per-layer cosine** of the native build against the PART 1 PyTorch combined-config reference before declaring done.
|
||||
|
||||
---
|
||||
|
||||
## Guardrails
|
||||
|
||||
- Don't reintroduce the **CUDA** `dequantize_nvfp4` for **weights** until the wrapper guards are in and the scale layout is confirmed — for now the PyTorch dequant is correct and crash-proof.
|
||||
- The two functions `dequant_nvfp4` (PyTorch, weights) and `dequantize_nvfp4` (CUDA, activations/KV) are a foot-gun. Consider renaming the CUDA one to `dequantize_nvfp4_kvcache` so this can't recur.
|
||||
- Only the **CSA indexer QK** path is FP4-QATed — do not let FP4 creep onto its non-QATed siblings.
|
||||
- Validate end-to-end (coherent + non-looping + clean stop + HCA-depth) **before** calling it done.
|
||||
@@ -124,15 +124,13 @@ __global__ void csa_compress_reduce_kernel(
|
||||
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
// Position bias: same (m, 2*hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias added ONLY to gate (softmax logit), NOT to KV content.
|
||||
// Paper eq. 11-12: compressed = softmax(Z + B) * C — bias B is on the
|
||||
// compression weights/logits, not on the KV content C.
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row >= 0 && pos_bias_row < m) {
|
||||
float pb = position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
g += pb;
|
||||
// kv_offset matches gate_offset for CSA: both are 0 (a-stream) or hd (b-stream)
|
||||
kv_val += position_bias[pos_bias_row * kv_dim + kv_offset + c];
|
||||
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
}
|
||||
}
|
||||
float e = expf(g - local_max[ci]);
|
||||
@@ -192,12 +190,11 @@ __global__ void hca_compress_reduce_kernel(
|
||||
if (token_idx >= T) break;
|
||||
float g = gate_proj[token_idx * hd + c];
|
||||
float kv_val = kv_proj[token_idx * hd + c];
|
||||
// Position bias: same (m, hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias added ONLY to gate (softmax logit), NOT to KV content.
|
||||
// Paper eq. 9-10: compressed = softmax(Z + B) * C — bias B is on the
|
||||
// compression weights/logits, not on the KV content C.
|
||||
if (position_bias != nullptr && t < m) {
|
||||
float pb = position_bias[t * hd + c];
|
||||
g += pb;
|
||||
kv_val += pb;
|
||||
g += position_bias[t * hd + c];
|
||||
}
|
||||
float e = expf(g - local_max);
|
||||
local_denom += e;
|
||||
|
||||
@@ -2196,12 +2196,11 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
neg_acc = acc_vec * cutlass.Float32(-1.0)
|
||||
exp_neg = cute.exp(neg_acc)
|
||||
sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg)
|
||||
silu_result = acc_vec * sigmoid
|
||||
# Paper §4.2.3: gate component capped at swiglu_limit
|
||||
# CuTe DSL clamp: min(x, limit) = cute.where(x > limit, limit, x)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||
limit = cutlass.Float32(self.swiglu_limit)
|
||||
silu_result = cute.where(silu_result > limit, limit, silu_result)
|
||||
acc_vec = cute.where(acc_vec > limit, limit, acc_vec)
|
||||
silu_result = acc_vec * sigmoid
|
||||
silu_result = silu_result.to(self.c_dtype)
|
||||
silu_gate_buf.store(silu_result)
|
||||
# Keep acc_vec in BF16 (same type as the up branch)
|
||||
|
||||
@@ -512,10 +512,11 @@ class Nvfp4MoE:
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
gate = gate.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
activated = gate_silu * up
|
||||
_, _, l2_gs = quantize_to_nvfp4(activated)
|
||||
|
||||
@@ -651,10 +652,11 @@ class Nvfp4MoE:
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
gate = gate.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
activated = gate_silu * up
|
||||
|
||||
# Compute runtime gsa for L2 from activated output (non-fused path)
|
||||
|
||||
@@ -9,7 +9,6 @@ NO PyTorch SDPA fallback. NO dequant+matmul for production projections.
|
||||
This is the ground truth for vLLM / SGLang integration.
|
||||
"""
|
||||
import os, sys, time, json, math, argparse, logging
|
||||
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # Catch async CUDA errors immediately
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
@@ -303,8 +302,6 @@ class Compressor:
|
||||
self.is_csa = (ratio == 4); self.kv_dim = 2 * head_dim if self.is_csa else head_dim
|
||||
self.kv_lin = None # production Nvfp4Linear for kv_proj
|
||||
self.gate_lin = None # production Nvfp4Linear for gate_proj
|
||||
self._kv_bf16 = None # BF16 weight for kv_proj (dequantized from NVFP4)
|
||||
self._gate_bf16 = None # BF16 weight for gate_proj (dequantized from NVFP4)
|
||||
self.ape = None; self.kv_norm_w = None
|
||||
self._reduce_loaded = False
|
||||
# P7: Decode buffering — accumulate hidden_states until we have a complete block.
|
||||
@@ -315,24 +312,26 @@ class Compressor:
|
||||
self._buf_len = 0
|
||||
|
||||
def load(self, w, pfx, dev=None):
|
||||
"""Load weights and build BF16 projections (dequantized from NVFP4)."""
|
||||
"""Load weights and build production Nvfp4Linear instances."""
|
||||
if dev is None: dev = self.device
|
||||
# Compressor projections are NOT explicitly FP4-QATed — dequant to BF16, use F.linear
|
||||
# CRITICAL: Use the PyTorch dequant_nvfp4 (defined in this file), NOT the CUDA
|
||||
# dequantize_nvfp4 from dsv4/ops/quantize.py. The CUDA kernel assumes
|
||||
# activation/KV scale layout (row-major (M, N/16)) and crashes on weight scales
|
||||
# that don't match — async illegal memory access surfaces at next sync.
|
||||
# Build production NVFP4 GEMM instances for the two projections
|
||||
# kv_proj: in=7168, out=kv_dim (1024 for CSA, 512 for HCA)
|
||||
# gate_proj: same shapes
|
||||
kv_w, kv_ws, kv_ws2, kv_isc = get_nvfp4_weight(w, pfx, 'kv_proj')
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
|
||||
if kv_w is not None:
|
||||
self._kv_bf16 = dequant_nvfp4(kv_w.to(dev), kv_ws.to(dev), kv_ws2, kv_isc).to(dev).contiguous()
|
||||
kv_out = kv_w.shape[0] # N_packed
|
||||
kv_in = kv_w.shape[1] * 2 # K_packed * 2
|
||||
self.kv_lin = make_nvfp4_linear(kv_in, kv_out, dev, w, pfx, 'kv_proj')
|
||||
if gate_w is not None:
|
||||
self._gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc).to(dev).contiguous()
|
||||
gate_out = gate_w.shape[0]
|
||||
gate_in = gate_w.shape[1] * 2
|
||||
self.gate_lin = make_nvfp4_linear(gate_in, gate_out, dev, w, pfx, 'gate_proj')
|
||||
self.ape = w.get(f"{pfx}.position_bias")
|
||||
self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
||||
|
||||
def forward(self, hidden_states, positions):
|
||||
if self.ratio == 0 or self._kv_bf16 is None: return None, None, None
|
||||
if self.ratio == 0 or self.kv_lin is None: return None, None, None
|
||||
T = hidden_states.shape[0]; r = self.ratio; dev = hidden_states.device
|
||||
|
||||
# P7: Buffer decode steps until we have a complete block.
|
||||
@@ -359,9 +358,9 @@ class Compressor:
|
||||
n_complete = T // r
|
||||
if n_complete == 0: return None, None, None
|
||||
|
||||
# Step 1-2: BF16 F.linear projections → FP32 for compress
|
||||
kv = torch.nn.functional.linear(hidden_states, self._kv_bf16).float() # (T, kv_dim) FP32
|
||||
gate = torch.nn.functional.linear(hidden_states, self._gate_bf16).float() # (T, kv_dim) FP32
|
||||
# Step 1-2: NVFP4 GEMM projections → FP32 for compress
|
||||
kv = self.kv_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
gate = self.gate_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
|
||||
# Step 3: CUDA softmax/reduce kernel → FP32
|
||||
# KV-1/KV-2: Return FP32. Caller applies RoPE, then quantizes to NVFP4.
|
||||
@@ -399,23 +398,22 @@ class Indexer:
|
||||
"""
|
||||
def __init__(self, n_ih, ihd, top_k, device):
|
||||
self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device
|
||||
self.q_b_lin = None # production Nvfp4Linear for q_b_proj (FP4-QATed)
|
||||
self._wp_bf16 = None # BF16 weight for weights_proj (dequantized from NVFP4)
|
||||
self.q_b_lin = None # production Nvfp4Linear for q_b_proj
|
||||
self.wp_lin = None # production Nvfp4Linear for weights_proj
|
||||
self.compressor = None
|
||||
|
||||
def load(self, w, pfx, dev=None):
|
||||
if dev is None: dev = self.device
|
||||
qb_w, qb_ws, qb_ws2, qb_isc = get_nvfp4_weight(w, pfx, 'q_b_proj')
|
||||
wp_w, wp_ws, wp_ws2, wp_isc = get_nvfp4_weight(w, pfx, 'weights_proj')
|
||||
# q_b_proj IS the FP4-QATed QK path — keep as NVFP4
|
||||
if qb_w is not None:
|
||||
qb_out = qb_w.shape[0]
|
||||
qb_in = qb_w.shape[1] * 2
|
||||
self.q_b_lin = make_nvfp4_linear(qb_in, qb_out, dev, w, pfx, 'q_b_proj')
|
||||
# weights_proj is NOT FP4-QATed — dequant to BF16 via PyTorch reference
|
||||
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4 (see Compressor.load)
|
||||
if wp_w is not None:
|
||||
self._wp_bf16 = dequant_nvfp4(wp_w.to(dev), wp_ws.to(dev), wp_ws2, wp_isc).to(dev).contiguous()
|
||||
wp_out = wp_w.shape[0]
|
||||
wp_in = wp_w.shape[1] * 2
|
||||
self.wp_lin = make_nvfp4_linear(wp_in, wp_out, dev, w, pfx, 'weights_proj')
|
||||
# Indexer compressor weights are directly under the indexer prefix
|
||||
# (e.g. *.indexer.kv_proj.weight), NOT nested under *.indexer.compressor.
|
||||
if f"{pfx}.kv_proj.weight" in w:
|
||||
@@ -438,7 +436,7 @@ class Indexer:
|
||||
li = layer_idx
|
||||
|
||||
q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd) # (T, n_ih, ihd)
|
||||
w_h = torch.nn.functional.linear(hidden_states, self._wp_bf16) # (T, n_ih) BF16
|
||||
w_h = self.wp_lin(hidden_states) # (T, n_ih)
|
||||
|
||||
# B2: FP8 tensor-core scoring path.
|
||||
# Indexer keys are stored as FP8_E4M3 in the KV cache.
|
||||
@@ -1308,26 +1306,50 @@ def main():
|
||||
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
# BF16 router gate — dequantize NVFP4 to BF16, use F.linear
|
||||
E = cfg["n_routed_experts"]
|
||||
# NVFP4 production GEMM for router gate
|
||||
# Custom CuTeDSL fused kernel crashes MLIR optimizer,
|
||||
# so we use Nvfp4Linear (proven production path).
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
# Checkpoint has NVFP4 gate weight — dequantize to BF16
|
||||
# CRITICAL: Use PyTorch dequant_nvfp4, NOT CUDA dequantize_nvfp4
|
||||
# (same fix as Compressor.load — CUDA kernel crashes on weight scale layouts)
|
||||
gate_bf16 = dequant_nvfp4(gate_w.to(dev), gate_ws.to(dev), gate_ws2, gate_isc)
|
||||
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
|
||||
# Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
|
||||
gate_lin.fp4 = [gate_w_view]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
|
||||
else:
|
||||
# BF16 gate weight from checkpoint
|
||||
# BF16 gate weight: quantize to NVFP4
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
gate_bf16 = gw.bfloat16().to(dev)
|
||||
if gate_bf16.shape[0] != H:
|
||||
gate_bf16 = gate_bf16.T.contiguous() # ensure (H, E)
|
||||
router.W_gate = gate_bf16.contiguous()
|
||||
# No gate_lin — force BF16 dispatch path
|
||||
router.gate_lin = None
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True)
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
if li < 5: print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
|
||||
else:
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
@@ -1375,11 +1397,21 @@ def main():
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
# lm_head: BF16 GEMM (checkpoint weight is BF16, no quantization)
|
||||
# lm_head: NVFP4 production GEMM
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
lm_head_lin = None # Use raw BF16 F.linear for lm_head
|
||||
lm_w = lm_w_raw # Keep as (V, H) BF16 for F.linear
|
||||
print(" lm_head: BF16 GEMM (checkpoint weight, no quantization)")
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
|
||||
lm_head_lin.gs = [lm_gs]
|
||||
lm_head_lin.ws2 = [None]
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
lm_head_lin._use_runtime_gsa = True
|
||||
lm_head_lin.finalize_weights()
|
||||
lm_w = None
|
||||
print(" lm_head: NVFP4 production GEMM")
|
||||
final_norm_w = all_w.get("model.norm.weight")
|
||||
if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)
|
||||
|
||||
@@ -1628,8 +1660,8 @@ def main():
|
||||
gl._activation_global_scale = fixed_gsa
|
||||
gl._use_runtime_gsa = False
|
||||
n_fixed += 1
|
||||
# lm_head (BF16 — no gsa needed)
|
||||
if lm_head_lin is not None and hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
# lm_head
|
||||
if hasattr(lm_head_lin, '_gsa_buf') and hasattr(lm_head_lin, '_use_runtime_gsa') and lm_head_lin._use_runtime_gsa:
|
||||
fixed_gsa = lm_head_lin._gsa_buf.item()
|
||||
lm_head_lin._activation_global_scale = fixed_gsa
|
||||
lm_head_lin._use_runtime_gsa = False
|
||||
@@ -1637,7 +1669,7 @@ def main():
|
||||
print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
logits = torch.nn.functional.linear(x_out, lm_w) if lm_head_lin is None else lm_head_lin(x_out)
|
||||
logits = lm_head_lin(x_out)
|
||||
if profile: torch.cuda.synchronize()
|
||||
t_lm = time.perf_counter()
|
||||
# Check thinking start token logit on first step
|
||||
|
||||
Reference in New Issue
Block a user