[Bugfix] Fix DSV3.2 NVFP4 (#33932)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-02-05 14:22:19 -05:00
Committed by: GitHub
Parent: 20f5d185a6
Commit: 4145e50d85


@@ -530,7 +530,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 scale=self._k_scale,
             )
-            if fp8_attention:
+            if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
                 kv_cache = kv_cache.view(current_platform.fp8_dtype())
 
         # Sparse MLA impls only support forward_mqa (decode-style attention)
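The first hunk stops the KV cache from being blindly reinterpreted as an fp8 tensor. With the "fp8_ds_mla" layout used by the DeepSeek V3.2 sparse MLA path, each cache entry mixes fp8 values with non-fp8 data (per-entry scales and an unquantized RoPE section), so a blanket .view(fp8_dtype) over the whole buffer would misread those bytes. Below is a minimal sketch of the gating pattern, with a hypothetical maybe_view_fp8 helper standing in for the inline check:

import torch

def maybe_view_fp8(kv_cache: torch.Tensor, kv_cache_dtype: str,
                   fp8_attention: bool) -> torch.Tensor:
    # Hypothetical helper mirroring the fixed condition: only reinterpret
    # the cache when every byte in it really is an fp8 value.
    if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
        return kv_cache.view(torch.float8_e4m3fn)
    # Mixed-layout caches (fp8 data + scales + rope) are left untouched.
    return kv_cache

cache = torch.zeros(4, 16, dtype=torch.uint8)  # stand-in byte buffer
print(maybe_view_fp8(cache, "fp8", True).dtype)         # torch.float8_e4m3fn
print(maybe_view_fp8(cache, "fp8_ds_mla", True).dtype)  # torch.uint8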
@@ -614,7 +614,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         # Convert from (N, B, L) to (B, N, L)
         mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
-        if fp8_attention:
+        if fp8_attention and self.impl.supports_quant_query_input:
             assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
             assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
             mqa_q = self._decode_concat_quant_fp8_op(
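The second hunk applies the same idea on the query side: the fused concat-and-quantize path is taken only when the backend can actually consume a pre-quantized query. A minimal sketch of the capability-gating pattern follows, with hypothetical class and function names (the real code calls self._decode_concat_quant_fp8_op and reads the flag off self.impl):

import torch

class BaseImpl:
    # Default capability, matching the flag added in the third hunk.
    supports_quant_query_input: bool = True

class NoQuantQueryImpl(BaseImpl):
    # A backend that needs the query in its original dtype.
    supports_quant_query_input = False

def build_decode_query(impl: BaseImpl, q_nope: torch.Tensor,
                       q_pe: torch.Tensor,
                       fp8_attention: bool) -> torch.Tensor:
    q = torch.cat([q_nope, q_pe], dim=-1)
    if fp8_attention and impl.supports_quant_query_input:
        # Stand-in for the fused concat+quant op in the real code.
        return q.to(torch.float8_e4m3fn)
    return q

q_nope = torch.randn(2, 8, 128, dtype=torch.bfloat16)
q_pe = torch.randn(2, 8, 64, dtype=torch.bfloat16)
print(build_decode_query(BaseImpl(), q_nope, q_pe, True).dtype)          # float8_e4m3fn
print(build_decode_query(NoQuantQueryImpl(), q_nope, q_pe, True).dtype)  # bfloat16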
@@ -1885,6 +1885,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
 
+        self.supports_quant_query_input = True
+
         # Use flashinfer's optimized concat_mla_k kernel when available.
         # The kernel is optimized for DeepSeek V3 dimensions:
         # num_heads=128, nope_dim=128, rope_dim=64
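Declaring self.supports_quant_query_input = True in MLACommonImpl.__init__ gives every MLA backend the attribute by default, so the new self.impl.supports_quant_query_input check in the decode path is always safe to evaluate. A backend that cannot accept a pre-quantized fp8 query, presumably the DeepSeek V3.2 sparse path this commit fixes, would override the flag to False and receive the query in its original dtype.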