Fix compressor: do not add positional bias to KV content
The positional bias (ape/B) should only modulate the compression softmax logits (Z + B), NOT be added to the KV content itself. Paper equation: compressed = softmax(Z + B) · C Bug was doing: compressed = softmax(Z + B) · (C + B) — poisons every compressed KV entry with learned positional-bias content. Fixed in both CSA (compress_csa_reduce_kernel) and HCA (hca_compress_reduce_kernel) paths in compressor_reduce.cu.
This commit is contained in:
@@ -124,15 +124,14 @@ __global__ void csa_compress_reduce_kernel(
|
||||
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
// Position bias: same (m, 2*hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias: added to gate logits (softmax Z + B) only.
|
||||
// The paper defines compression as softmax(Z + B) then weighted sum of C.
|
||||
// The bias must NOT be added to kv_val — that poisons compressed content.
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row >= 0 && pos_bias_row < m) {
|
||||
float pb = position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
g += pb;
|
||||
// kv_offset matches gate_offset for CSA: both are 0 (a-stream) or hd (b-stream)
|
||||
kv_val += position_bias[pos_bias_row * kv_dim + kv_offset + c];
|
||||
}
|
||||
}
|
||||
float e = expf(g - local_max[ci]);
|
||||
@@ -192,12 +191,12 @@ __global__ void hca_compress_reduce_kernel(
|
||||
if (token_idx >= T) break;
|
||||
float g = gate_proj[token_idx * hd + c];
|
||||
float kv_val = kv_proj[token_idx * hd + c];
|
||||
// Position bias: same (m, hd) bias added to every block
|
||||
// Added to BOTH gate (softmax logit) and kv (content) per reference
|
||||
// Position bias: added to gate logits (softmax Z + B) only.
|
||||
// The paper defines compression as softmax(Z + B) then weighted sum of C.
|
||||
// The bias must NOT be added to kv_val — that poisons compressed content.
|
||||
if (position_bias != nullptr && t < m) {
|
||||
float pb = position_bias[t * hd + c];
|
||||
g += pb;
|
||||
kv_val += pb;
|
||||
}
|
||||
float e = expf(g - local_max);
|
||||
local_denom += e;
|
||||
|
||||
Reference in New Issue
Block a user