Fix compressor: do not add positional bias to KV content

The positional bias (ape/B) should only modulate the compression
softmax logits (Z + B), NOT be added to the KV content itself.

Paper equation: compressed = softmax(Z + B) · C
Bug was doing: compressed = softmax(Z + B) · (C + B) — poisons every
compressed KV entry with learned positional-bias content.

Fixed in both CSA (compress_csa_reduce_kernel) and HCA
(hca_compress_reduce_kernel) paths in compressor_reduce.cu.
This commit is contained in:
2026-06-03 15:52:00 +00:00
parent 4fe73fe713
commit ca5bc814d5

View File

@@ -124,15 +124,14 @@ __global__ void csa_compress_reduce_kernel(
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
// Position bias: same (m, 2*hd) bias added to every block
// Added to BOTH gate (softmax logit) and kv (content) per reference
// Position bias: added to gate logits (softmax Z + B) only.
// The paper defines compression as softmax(Z + B) then weighted sum of C.
// The bias must NOT be added to kv_val — that poisons compressed content.
if (position_bias != nullptr) {
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
if (pos_bias_row >= 0 && pos_bias_row < m) {
float pb = position_bias[pos_bias_row * kv_dim + gate_offset + c];
g += pb;
// kv_offset matches gate_offset for CSA: both are 0 (a-stream) or hd (b-stream)
kv_val += position_bias[pos_bias_row * kv_dim + kv_offset + c];
}
}
float e = expf(g - local_max[ci]);
@@ -192,12 +191,12 @@ __global__ void hca_compress_reduce_kernel(
if (token_idx >= T) break;
float g = gate_proj[token_idx * hd + c];
float kv_val = kv_proj[token_idx * hd + c];
// Position bias: same (m, hd) bias added to every block
// Added to BOTH gate (softmax logit) and kv (content) per reference
// Position bias: added to gate logits (softmax Z + B) only.
// The paper defines compression as softmax(Z + B) then weighted sum of C.
// The bias must NOT be added to kv_val — that poisons compressed content.
if (position_bias != nullptr && t < m) {
float pb = position_bias[t * hd + c];
g += pb;
kv_val += pb;
}
float e = expf(g - local_max);
local_denom += e;