[Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" leading to gibberish (#37054)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
@@ -1319,10 +1319,14 @@ class FlashInferImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
self.bmm1_scale = self.scale
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
self.bmm2_scale = 1.0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm2_scale *= layer._v_scale_float
|
||||
|
||||
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||
|
||||
@@ -255,6 +255,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if layer._q_scale_float != 1.0 or layer._k_scale_float != 1.0:
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support scaling for q and kv_latent yet"
|
||||
)
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
|
||||
@@ -177,9 +177,14 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
self.bmm1_scale = self.scale
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
self.bmm2_scale = 1.0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm2_scale *= layer._k_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
|
||||
@@ -340,9 +340,13 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
|
||||
self._workspace_buffer = _get_workspace_buffer(q.device)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
self.bmm1_scale = self.scale
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
self.bmm2_scale = 1.0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
self.bmm2_scale *= layer._k_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q.unsqueeze(1),
|
||||
|
||||
@@ -187,7 +187,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self.scale,
|
||||
PAGE_SIZE,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
v_scale=layer._k_scale,
|
||||
)
|
||||
|
||||
return o, lse
|
||||
|
||||
Reference in New Issue
Block a user