Fix set_indexer_keys_fp8 None guard + store comp_pos in mixed storage

This commit is contained in:
2026-06-02 10:20:26 +00:00
parent 1f69f61363
commit c89762ecdd

View File

@@ -540,28 +540,31 @@ class KVCache:
self.swa_head = (self.swa_head + T) % self.ws
self.swa_len = min(self.swa_len + T, self.ws)
def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16):
def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16, comp_pos=None):
"""Add compressed KV entries (mixed FP8 nope + BF16 rope)."""
T = nope_fp8.shape[0]
end = self.n_comp
self.comp_nope_fp8[end:end+T] = nope_fp8.view(torch.uint8) if nope_fp8.dtype != torch.uint8 else nope_fp8
self.comp_nope_scale[end:end+T] = nope_scale
self.comp_rope_bf16[end:end+T] = rope_bf16
if comp_pos is not None:
self.comp_pos_buf[end:end+T] = comp_pos
self.n_comp = end + T
def set_indexer_keys_fp8(self, idx_kv):
"""Add indexer compressed keys. idx_kv is BF16 (n_comp, ihd) or FP8 (fp8, scale)."""
T = self.n_comp # should match compressed KV count
end = T - (idx_kv[0].shape[0] if isinstance(idx_kv, tuple) else idx_kv.shape[0])
if idx_kv is None: return
T = idx_kv[0].shape[0] if isinstance(idx_kv, tuple) else idx_kv.shape[0]
end = self.n_comp - T
if isinstance(idx_kv, tuple) and len(idx_kv) == 2:
fp8, scale = idx_kv
self.comp_idx_fp8[end:end+T-end] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
self.comp_idx_scale[end:end+T-end] = scale
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
self.comp_idx_scale[end:end+T] = scale
elif isinstance(idx_kv, torch.Tensor):
mod = self._get_kv_quant_mod()
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())
self.comp_idx_fp8[end:self.n_comp] = fp8.view(torch.uint8)
self.comp_idx_scale[end:self.n_comp] = scale
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8)
self.comp_idx_scale[end:end+T] = scale
self._has_idx = True
def comp_nope_selective(self, indices):
@@ -709,8 +712,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16
# Quantize non-RoPE part FP32 → FP8_E4M3
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
# Store mixed-format compressed KV
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16)
# Store mixed-format compressed KV + positions
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
# Indexer keys: FP8_E4M3 (ihd=128, no RoPE)