[MLA] Fuse cat and quant for fp8 kv-cache (#32950)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

Author: Lucas Wilkinson
Date: 2026-01-24 09:03:02 -07:00
Committer: GitHub
Commit: da5e7b12be (parent 719ac592ed)


@@ -202,13 +202,16 @@ from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
 )
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
     get_and_maybe_dequant_weights,
 )
 from vllm.platforms import current_platform
@@ -287,6 +290,37 @@ def dynamic_per_batched_tensor_quant(
 logger = init_logger(__name__)
 
 
+@CustomOp.register("mla_decode_concat_quant_fp8")
+class _DecodeConcatQuantFP8(QuantFP8):
+    """
+    QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before
+    quantization. When disabled, forward_native is compiled via torch.compile,
+    fusing cat/reshape/quant/view together.
+    """
+
+    def _make_forward(quant_fn):  # noqa: N805
+        """Factory to create forward methods that concat before quantization."""
+
+        def forward(
+            self,
+            decode_ql_nope: torch.Tensor,
+            decode_q_pe: torch.Tensor,
+            scale: torch.Tensor,
+            scale_ub: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            decode_q0 = torch.cat((decode_ql_nope, decode_q_pe), dim=-1)
+            decode_q_flat = decode_q0.reshape(decode_q0.shape[0], -1)
+            decode_q, _ = quant_fn(self, decode_q_flat, scale, scale_ub)
+            return decode_q.view(decode_q0.shape)
+
+        return forward
+
+    forward_native = _make_forward(QuantFP8.forward_native)  # type: ignore[arg-type]
+    forward_cuda = _make_forward(QuantFP8.forward_cuda)  # type: ignore[arg-type]
+    forward_hip = _make_forward(QuantFP8.forward_hip)  # type: ignore[arg-type]
+
+
 CUDNN_WORKSPACE_SIZE = 12800
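
As an aside for readers outside the vLLM tree, the snippet below is a minimal sketch of the eager-mode pattern this op captures (cat, flatten, static per-tensor FP8 quantize, view), written in plain PyTorch so that torch.compile can fuse it into a single kernel, which is what compile_native=True arranges for forward_native. The helper name, shapes, and quantization arithmetic here are illustrative assumptions, not code from the commit; vLLM's QuantFP8 kernels do the real work.

    import torch

    # Illustrative sketch only: mirrors the fused pattern, not vLLM's kernels.
    FP8 = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8).max

    def concat_quant_fp8(ql_nope, q_pe, scale):
        # ql_nope: [B, H, D_nope], q_pe: [B, H, D_pe], scale: precomputed [1] tensor
        q = torch.cat((ql_nope, q_pe), dim=-1)  # [B, H, D_nope + D_pe]
        q_flat = q.reshape(q.shape[0], -1)      # quant kernels expect a 2D input
        q_fp8 = (q_flat / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
        return q_fp8.view(q.shape)

    # cat/reshape/quant/view fuse into one compiled kernel
    fused = torch.compile(concat_quant_fp8)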
@@ -1398,6 +1432,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         self.cp_kv_cache_interleave_size: int = (
             get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
         )
 
+        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
+            static=True,
+            group_shape=GroupShape.PER_TENSOR,
+            compile_native=True,
+        )
+
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
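
A hedged note on the constructor arguments above: static=True together with GroupShape.PER_TENSOR means the op consumes one precomputed scale for the whole tensor (layer._q_scale at the call site) rather than deriving a scale per call, and compile_native=True requests the torch.compile'd native path. The snippet below is an illustrative assumption about where such a static scale typically comes from (offline calibration), not code from this commit:

    import torch

    def calibrate_per_tensor_scale(sample: torch.Tensor) -> torch.Tensor:
        # Hypothetical offline calibration: pick one scale so the observed
        # activation range maps onto the fp8 e4m3 representable range.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        return (sample.abs().amax() / fp8_max).clamp(min=1e-12).reshape(1)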
@@ -2048,29 +2087,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             decode_ql_nope = decode_ql_nope.transpose(0, 1)
 
         if fp8_attention:
-            ql_nope_shape = decode_ql_nope.shape
-            q_pe_shape = decode_q_pe.shape
-            assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
-            assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
-            decode_q_shape = (
-                ql_nope_shape[0],
-                ql_nope_shape[1],
-                ql_nope_shape[2] + q_pe_shape[2],
+            decode_q = self._decode_concat_quant_fp8_op(
+                decode_ql_nope, decode_q_pe, layer._q_scale
             )
-            # Using empty and copy since torch.cat introduces significant overhead.
-            decode_q0 = torch.empty(
-                decode_q_shape,
-                device=decode_ql_nope.device,
-                dtype=decode_ql_nope.dtype,
-            )
-            decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
-            decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
-            decode_q, _ = ops.scaled_fp8_quant(
-                decode_q0.view(decode_q_shape[0], -1),
-                layer._q_scale,
-            )
-            decode_q = decode_q.view(decode_q_shape)
         else:
             decode_q = (decode_ql_nope, decode_q_pe)
 
         if self.dcp_world_size > 1:
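
The deleted branch hand-rolled the concatenation with empty plus copy_ to avoid torch.cat's eager-mode overhead, as its own comment notes; the fused custom op makes that workaround unnecessary. A small self-contained check, not from the commit, that the two constructions agree:

    import torch

    def cat_via_copy(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Preallocate the output and copy each half in, as the deleted code did.
        out = torch.empty(
            (*a.shape[:-1], a.shape[-1] + b.shape[-1]),
            device=a.device,
            dtype=a.dtype,
        )
        out[..., : a.shape[-1]].copy_(a)
        out[..., a.shape[-1] :].copy_(b)
        return out

    a, b = torch.randn(16, 8, 512), torch.randn(16, 8, 64)
    assert torch.equal(cat_via_copy(a, b), torch.cat((a, b), dim=-1))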