def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16, comp_pos=None):
    """Append compressed KV entries (mixed FP8 nope + BF16 rope).

    Writes ``nope_fp8.shape[0]`` new rows starting at row ``self.n_comp``
    of the compressed buffers, then advances ``self.n_comp`` by that amount.

    Args:
        nope_fp8: Quantized non-RoPE part. Stored as a ``torch.uint8`` view
            (bit reinterpretation, not a numeric cast) unless it already
            has that dtype.
        nope_scale: Scales that accompany ``nope_fp8``.
        rope_bf16: RoPE part of the entries.
        comp_pos: Optional positions for the new rows. When ``None`` the
            matching rows of ``self.comp_pos_buf`` are left untouched.
            NOTE(review): those rows then keep whatever they held before;
            confirm no reader relies on them in that case.
    """
    count = nope_fp8.shape[0]
    start = self.n_comp
    stop = start + count

    if nope_fp8.dtype != torch.uint8:
        nope_fp8 = nope_fp8.view(torch.uint8)

    self.comp_nope_fp8[start:stop] = nope_fp8
    self.comp_nope_scale[start:stop] = nope_scale
    self.comp_rope_bf16[start:stop] = rope_bf16
    if comp_pos is not None:
        self.comp_pos_buf[start:stop] = comp_pos
    self.n_comp = stop


def set_indexer_keys_fp8(self, idx_kv):
    """Store indexer keys for the most recently added compressed entries.

    The keys are written to the last ``T`` rows below ``self.n_comp``, where
    ``T`` is the number of keys supplied, so this is meant to be called
    after ``set_compressed_mixed`` has advanced ``self.n_comp``.

    Args:
        idx_kv: One of
            * ``None`` -- nothing is stored and the call is a no-op;
            * a ``(fp8, scale)`` tuple of already-quantized keys;
            * a ``torch.Tensor`` that is converted to FP32 and quantized
              through ``self._get_kv_quant_mod()`` before being stored.

    Raises:
        TypeError: If ``idx_kv`` is none of the accepted forms. Previously
            such input was skipped silently while ``_has_idx`` was still
            set, advertising keys that were never written.
        ValueError: If more keys are supplied than compressed entries
            exist; the resulting negative slice start would otherwise wrap
            around to the tail of the buffer.
    """
    if idx_kv is None:
        return

    is_pair = isinstance(idx_kv, tuple) and len(idx_kv) == 2
    if not is_pair and not isinstance(idx_kv, torch.Tensor):
        raise TypeError(
            "idx_kv must be a torch.Tensor or an (fp8, scale) tuple, "
            f"got {type(idx_kv).__name__}"
        )

    count = idx_kv[0].shape[0] if is_pair else idx_kv.shape[0]
    stop = self.n_comp
    start = stop - count
    if start < 0:
        raise ValueError(
            f"got {count} indexer keys but only {stop} compressed entries"
        )

    if is_pair:
        fp8, scale = idx_kv
    else:
        mod = self._get_kv_quant_mod()
        fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())

    if fp8.dtype != torch.uint8:
        fp8 = fp8.view(torch.uint8)

    self.comp_idx_fp8[start:stop] = fp8
    self.comp_idx_scale[start:stop] = scale
    # Only flag indexer keys as present once they have actually been written.
    self._has_idx = True
forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16 # Quantize non-RoPE part FP32 → FP8_E4M3 nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) - # Store mixed-format compressed KV - kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16) + # Store mixed-format compressed KV + positions + kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos) if compressor.is_csa and indexer is not None and indexer.compressor is not None: comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) # Indexer keys: FP8_E4M3 (ihd=128, no RoPE)