diff --git a/vllm/patches/layers/csa_attention.py b/vllm/patches/layers/csa_attention.py index c315fed5..c9f2b9af 100644 --- a/vllm/patches/layers/csa_attention.py +++ b/vllm/patches/layers/csa_attention.py @@ -247,9 +247,9 @@ def blackwell_attention_decode( kv_cached_raw = swa_kv_cache[block_indices, offsets] if swa_kv_cache.dtype == torch.uint8: kv_cached_raw = kv_cached_raw.view(torch.float8_e4m3fn) - # Dequantize: for now use bf16 cast (fp8 → bf16 without per-token scale) - # TODO: store and read per-token inv_scale in paged cache - kv_cached = kv_cached_raw.to(torch.bfloat16) + # Dequantize with per-token inverse scale + inv_scales = swa_inv_scale[indices] + kv_cached = kv_dequantize_fp8(kv_cached_raw, inv_scales) else: # Fallback: sequential slot access pos = positions[0].item()