Fix two correctness bugs: compressor pos bias on KV + SwiGLU clamp ordering
1. Compressor positional bias was being added to BOTH gate (softmax logit) AND KV content. Per paper eq. 9-12, position bias is only for the softmax logits (Z+B), NOT the KV content (C). Adding pb to kv_val corrupts every compressed KV entry with learned positional-bias content. Fixed in both CSA and HCA paths in compressor_reduce.cu. 2. SwiGLU clamp ordering: code was clamping silu(gate) instead of clamping raw gate before SiLU. Per paper §4.2.3: gate = clamp(gate, max=limit), then silu(clamp(gate)) * clamp(up). Fixed in moe.py (both unfused paths) and fused_swiglu.py (CuTeDSL kernel). shared_expert.py was already correct.
This commit is contained in:
59
CORRECTNESS_FIX_ATTEMPTS.md
Normal file
59
CORRECTNESS_FIX_ATTEMPTS.md
Normal file
@@ -0,0 +1,59 @@
|
||||
## 1. Possible bug: compressor positional bias is being added to KV content
|
||||
|
||||
In your `dsv4/kernels/cuda/compressor_reduce.cu`, the compressor appears to do this in both CSA and HCA paths:
|
||||
|
||||
```cpp
|
||||
g += pb;
|
||||
kv_val += pb; // suspicious / wrong
|
||||
```
|
||||
|
||||
The official compressor equations add positional bias only to the **compression weights/logits** `Z + B`, then use those weights to sum the raw projected KV content `C`. The bias is not added to the KV value itself. The paper defines compression as softmax over `Z + B`, followed by a weighted sum of `C`.
|
||||
|
||||
So this should be:
|
||||
|
||||
```cpp
|
||||
g += pb;
|
||||
// do not add pb to kv_val
|
||||
```
|
||||
|
||||
That bug would poison every compressed KV entry with learned positional-bias content. It may not fully explain the first token for a tiny prompt if SWA dominates, but it is absolutely wrong relative to the official architecture and will degrade CSA/HCA context quality. If your unit tests passed, they may have been comparing against a reference that made the same mistake or were too short to expose it.
|
||||
|
||||
## 2. Don’t use `think_start` as the canary here
|
||||
|
||||
In official `thinking` mode, the prompt formatter typically appends the assistant marker plus `<think>` before generation. That means decode step 0 is already *inside* the thinking span. The model should not necessarily emit `think_start`; a low `think_start` logit is not itself evidence that the model “failed to enter thinking mode.”
|
||||
|
||||
For this particular prompt, a high `think_end` logit can even be plausible because “The capital of France is” does not need much reasoning. Run the same current kernel with official **chat mode**, greedy decoding, no repetition penalty, no top-k/top-p, and compare first-token logits. The `think_start = 1.77` observation is probably a misleading diagnostic.
|
||||
|
||||
## 3. Indexer parity still looks suspect
|
||||
|
||||
The official CSA/HCA details include RMSNorm on queries and compressed KV before attention, partial RoPE on the last 64 dims, sliding-window KV, and attention sink behavior.
|
||||
|
||||
For the CSA **indexer**, the official reference does more than a plain q/k dot product: indexer Q and compressed indexer K get the appropriate RoPE/rotation treatment, and the QK path is one of the explicitly FP4-QATed pieces. If your current indexer compressor is just producing compressed keys without the same rotate/RoPE/FP4 path, CSA top-k can select plausible-looking but wrong blocks. Again, probably not the first-token issue on a short prompt, but it will matter for any real context.
|
||||
|
||||
## 4. Check SwiGLU clamp ordering
|
||||
|
||||
The official behavior is effectively:
|
||||
|
||||
```python
|
||||
gate = clamp(gate, max=swiglu_limit)
|
||||
up = clamp(up, -swiglu_limit, swiglu_limit)
|
||||
out = silu(gate) * up
|
||||
```
|
||||
|
||||
If your fused path clamps `silu(gate)` instead of clamping raw `gate` before SiLU, it is not equivalent. This is especially worth checking in both routed MoE and shared expert fused kernels, because a small-looking activation semantic mismatch repeats through every layer.
|
||||
|
||||
|
||||
## 5. DEQUANT TO BF16 IN THIS ORDER JUST TO SEE WHAT HAPPENS (You are allowed to break the no bf16 rule for this because we can always revert back to previous commit)
|
||||
|
||||
|
||||
The most suspicious surfaces to temporarily dequantize are, in order:
|
||||
|
||||
1. **lm head** — FP4 lm head can directly flatten or reorder vocabulary logits.
|
||||
2. **router gate** — slight errors can change top-6 experts; wrong expert IDs are much worse than a small GEMM error.
|
||||
3. **shared expert** — official routed experts are the FP4 target; shared expert is a different sensitivity profile.
|
||||
4. **attention q/kv/o projections and grouped output projection** — these are not described as full FP4 QAT targets.
|
||||
5. **compressor/indexer helper projections** — only the CSA indexer QK path is explicitly FP4-QATed, not the whole compressor.
|
||||
|
||||
If a BF16/FP8 lm head alone makes `Paris` / `.` / answer-like tokens dominate again, you’ve found a high-leverage culprit. My money is on LM Head needing to be BF16
|
||||
|
||||
The fastest triage is basically: run `thinking_mode=chat`, greedy; switch only `lm_head` back to BF16/FP8; then switch router back; then patch the compressor bias-to-KV bug. If any one of those sharply separates the first-token distribution, you’ll know where to spend kernel time.
|
||||
@@ -124,15 +124,13 @@ __global__ void csa_compress_reduce_kernel(
|
||||
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
// Position bias: same (m, 2*hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias added ONLY to gate (softmax logit), NOT to KV content.
|
||||
// Paper eq. 11-12: compressed = softmax(Z + B) * C — bias B is on the
|
||||
// compression weights/logits, not on the KV content C.
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row >= 0 && pos_bias_row < m) {
|
||||
float pb = position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
g += pb;
|
||||
// kv_offset matches gate_offset for CSA: both are 0 (a-stream) or hd (b-stream)
|
||||
kv_val += position_bias[pos_bias_row * kv_dim + kv_offset + c];
|
||||
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
}
|
||||
}
|
||||
float e = expf(g - local_max[ci]);
|
||||
@@ -192,12 +190,11 @@ __global__ void hca_compress_reduce_kernel(
|
||||
if (token_idx >= T) break;
|
||||
float g = gate_proj[token_idx * hd + c];
|
||||
float kv_val = kv_proj[token_idx * hd + c];
|
||||
// Position bias: same (m, hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias added ONLY to gate (softmax logit), NOT to KV content.
|
||||
// Paper eq. 9-10: compressed = softmax(Z + B) * C — bias B is on the
|
||||
// compression weights/logits, not on the KV content C.
|
||||
if (position_bias != nullptr && t < m) {
|
||||
float pb = position_bias[t * hd + c];
|
||||
g += pb;
|
||||
kv_val += pb;
|
||||
g += position_bias[t * hd + c];
|
||||
}
|
||||
float e = expf(g - local_max);
|
||||
local_denom += e;
|
||||
|
||||
@@ -2196,12 +2196,11 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
neg_acc = acc_vec * cutlass.Float32(-1.0)
|
||||
exp_neg = cute.exp(neg_acc)
|
||||
sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg)
|
||||
silu_result = acc_vec * sigmoid
|
||||
# Paper §4.2.3: gate component capped at swiglu_limit
|
||||
# CuTe DSL clamp: min(x, limit) = cute.where(x > limit, limit, x)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||
limit = cutlass.Float32(self.swiglu_limit)
|
||||
silu_result = cute.where(silu_result > limit, limit, silu_result)
|
||||
acc_vec = cute.where(acc_vec > limit, limit, acc_vec)
|
||||
silu_result = acc_vec * sigmoid
|
||||
silu_result = silu_result.to(self.c_dtype)
|
||||
silu_gate_buf.store(silu_result)
|
||||
# Keep acc_vec in BF16 (same type as the up branch)
|
||||
|
||||
@@ -512,10 +512,11 @@ class Nvfp4MoE:
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
gate = gate.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
activated = gate_silu * up
|
||||
_, _, l2_gs = quantize_to_nvfp4(activated)
|
||||
|
||||
@@ -651,10 +652,11 @@ class Nvfp4MoE:
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
gate = l1_deil[:, :self.intermediate_size]
|
||||
up = l1_deil[:, self.intermediate_size:]
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
# Paper §4.2.3: clamp raw gate BEFORE SiLU, not after
|
||||
if self._swiglu_limit is not None:
|
||||
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
||||
gate = gate.clamp(max=self._swiglu_limit)
|
||||
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
activated = gate_silu * up
|
||||
|
||||
# Compute runtime gsa for L2 from activated output (non-fused path)
|
||||
|
||||
Reference in New Issue
Block a user