CRITICAL FIX: Properly dequantize fp8 KV in decode using per-token inv_scale
This commit is contained in:
@@ -247,9 +247,9 @@ def blackwell_attention_decode(
|
||||
kv_cached_raw = swa_kv_cache[block_indices, offsets]
|
||||
if swa_kv_cache.dtype == torch.uint8:
|
||||
kv_cached_raw = kv_cached_raw.view(torch.float8_e4m3fn)
|
||||
- # Dequantize: for now use bf16 cast (fp8 → bf16 without per-token scale)
|
||||
- # TODO: store and read per-token inv_scale in paged cache
|
||||
- kv_cached = kv_cached_raw.to(torch.bfloat16)
|
||||
+ # Dequantize with per-token inverse scale
|
||||
+ inv_scales = swa_inv_scale[indices]
|
||||
+ kv_cached = kv_dequantize_fp8(kv_cached_raw, inv_scales)
|
||||
else:
|
||||
# Fallback: sequential slot access
|
||||
pos = positions[0].item()
|
||||
|
||||
Reference in New Issue
Block a user