[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                     "Setting it to k_scale. This only matters for "
                     "the flash-attn backend.")
                 layer._q_scale.copy_(k_scale)
+                layer._q_scale_float = k_scale
 
             # These are used in the final Attention.forward()
             layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
 
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
+        layer._q_scale_float = q_scale
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
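
Both hunks follow the same pattern: each scaling factor is kept as a tensor buffer (updated in place with copy_) and, with this change, also mirrored into a plain Python float (_q_scale_float) so host-side code such as the FP8 FlashInfer (TRTLLM) MLA decode path can read the scale without touching the device. The sketch below illustrates that pattern only; the class and method names are hypothetical, not vLLM's actual API.

# Minimal, hypothetical sketch of the scale-mirroring pattern shown in the
# hunks above. Names are illustrative; only the pattern matches the diff.
import torch


class ToyAttentionLayer:
    def __init__(self) -> None:
        # Scales live in tensors for kernels that consume device buffers...
        self._q_scale = torch.tensor(1.0)
        self._k_scale = torch.tensor(1.0)
        # ...and also as a plain float for backends that want a host-side
        # scalar (e.g. an FP8 decode path) without a device sync.
        self._q_scale_float = 1.0

    def load_scales(self, q_scale, k_scale):
        if q_scale is None:
            # No q scaling factor in the checkpoint: fall back to k_scale,
            # mirroring the warning branch in the first hunk.
            q_scale = k_scale
        self._q_scale.copy_(q_scale)
        self._q_scale_float = q_scale  # the float mirror this commit keeps in sync
        self._k_scale.copy_(k_scale)


layer = ToyAttentionLayer()
layer.load_scales(q_scale=None, k_scale=0.02)
print(layer._q_scale, layer._q_scale_float)  # tensor(0.0200) 0.02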