CRITICAL FIX: Properly dequantize fp8 KV in decode using per-token inv_scale

This commit is contained in:
2026-05-19 17:08:58 +00:00
parent 2f811bc8bd
commit aade8593f7

View File

@@ -247,9 +247,9 @@ def blackwell_attention_decode(
kv_cached_raw = swa_kv_cache[block_indices, offsets]
if swa_kv_cache.dtype == torch.uint8:
kv_cached_raw = kv_cached_raw.view(torch.float8_e4m3fn)
# Dequantize: for now use bf16 cast (fp8 → bf16 without per-token scale)
# TODO: store and read per-token inv_scale in paged cache
kv_cached = kv_cached_raw.to(torch.bfloat16)
# Dequantize with per-token inverse scale
inv_scales = swa_inv_scale[indices]
kv_cached = kv_dequantize_fp8(kv_cached_raw, inv_scales)
else:
# Fallback: sequential slot access
pos = positions[0].item()