Fix set_indexer_keys_fp8 None guard + store comp_pos in mixed storage
This commit is contained in:
@@ -540,28 +540,31 @@ class KVCache:
|
||||
self.swa_head = (self.swa_head + T) % self.ws
|
||||
self.swa_len = min(self.swa_len + T, self.ws)
|
||||
|
||||
def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16):
|
||||
def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16, comp_pos=None):
|
||||
"""Add compressed KV entries (mixed FP8 nope + BF16 rope)."""
|
||||
T = nope_fp8.shape[0]
|
||||
end = self.n_comp
|
||||
self.comp_nope_fp8[end:end+T] = nope_fp8.view(torch.uint8) if nope_fp8.dtype != torch.uint8 else nope_fp8
|
||||
self.comp_nope_scale[end:end+T] = nope_scale
|
||||
self.comp_rope_bf16[end:end+T] = rope_bf16
|
||||
if comp_pos is not None:
|
||||
self.comp_pos_buf[end:end+T] = comp_pos
|
||||
self.n_comp = end + T
|
||||
|
||||
def set_indexer_keys_fp8(self, idx_kv):
|
||||
"""Add indexer compressed keys. idx_kv is BF16 (n_comp, ihd) or FP8 (fp8, scale)."""
|
||||
T = self.n_comp # should match compressed KV count
|
||||
end = T - (idx_kv[0].shape[0] if isinstance(idx_kv, tuple) else idx_kv.shape[0])
|
||||
if idx_kv is None: return
|
||||
T = idx_kv[0].shape[0] if isinstance(idx_kv, tuple) else idx_kv.shape[0]
|
||||
end = self.n_comp - T
|
||||
if isinstance(idx_kv, tuple) and len(idx_kv) == 2:
|
||||
fp8, scale = idx_kv
|
||||
self.comp_idx_fp8[end:end+T-end] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
|
||||
self.comp_idx_scale[end:end+T-end] = scale
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
elif isinstance(idx_kv, torch.Tensor):
|
||||
mod = self._get_kv_quant_mod()
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())
|
||||
self.comp_idx_fp8[end:self.n_comp] = fp8.view(torch.uint8)
|
||||
self.comp_idx_scale[end:self.n_comp] = scale
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8)
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
self._has_idx = True
|
||||
|
||||
def comp_nope_selective(self, indices):
|
||||
@@ -709,8 +712,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16
|
||||
# Quantize non-RoPE part FP32 → FP8_E4M3
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
# Store mixed-format compressed KV
|
||||
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16)
|
||||
# Store mixed-format compressed KV + positions
|
||||
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
|
||||
# Indexer keys: FP8_E4M3 (ihd=128, no RoPE)
|
||||
|
||||
Reference in New Issue
Block a user