[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
Matthew Bonanni
2025-09-12 17:45:53 -04:00
committed by GitHub
parent c89ed8de43
commit 7ba32aa60b
8 changed files with 23 additions and 10 deletions

View File

@@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"Setting it to k_scale. This only matters for "
"the flash-attn backend.")
layer._q_scale.copy_(k_scale)
layer._q_scale_float = k_scale
# These are used in the final Attention.forward()
layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._q_scale_float = q_scale
layer._prob_scale.copy_(prob_scale)
if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
or prob_scale == 1.0):