[MLA] Fuse cat and quant for fp8 kv-cache (#32950)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -202,13 +202,16 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_and_maybe_dequant_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -287,6 +290,37 @@ def dynamic_per_batched_tensor_quant(
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@CustomOp.register("mla_decode_concat_quant_fp8")
class _DecodeConcatQuantFP8(QuantFP8):
    """
    QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before
    quantization. When disabled, forward_native is compiled via torch.compile,
    fusing cat/reshape/quant/view together.
    """

    # NOTE: intentionally not a regular method — `quant_fn` (not `self`) is the
    # first argument, hence the N805 suppression. This runs at class-body
    # evaluation time to stamp out one concat-then-quantize forward per backend.
    def _make_forward(quant_fn):  # noqa: N805
        """Factory to create forward methods that concat before quantization."""

        def forward(
            self,
            decode_ql_nope: torch.Tensor,
            decode_q_pe: torch.Tensor,
            scale: torch.Tensor,
            scale_ub: torch.Tensor | None = None,
        ) -> torch.Tensor:
            # Join the nope and rope halves of the decode query along the last
            # (head-dim) axis; assumes leading dims of both inputs match —
            # callers assert this before invoking the op.
            decode_q0 = torch.cat((decode_ql_nope, decode_q_pe), dim=-1)
            # Flatten to 2-D (tokens, heads*head_dim) — presumably the QuantFP8
            # kernels expect a 2-D input; verify against QuantFP8.forward_*.
            decode_q_flat = decode_q0.reshape(decode_q0.shape[0], -1)
            # Quantize with the bound backend kernel; the returned scale is
            # discarded since `scale` is supplied by the caller (static scale).
            decode_q, _ = quant_fn(self, decode_q_flat, scale, scale_ub)
            # Restore the original 3-D concatenated shape for the fp8 output.
            return decode_q.view(decode_q0.shape)

        return forward

    # One forward per backend so CustomOp dispatch picks the right quant
    # kernel; the factory's signature differs from QuantFP8's, hence the
    # type-ignore on each binding.
    forward_native = _make_forward(QuantFP8.forward_native)  # type: ignore[arg-type]
    forward_cuda = _make_forward(QuantFP8.forward_cuda)  # type: ignore[arg-type]
    forward_hip = _make_forward(QuantFP8.forward_hip)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Workspace size handed to the cuDNN attention path; unit not shown here —
# presumably elements or bytes per some cuDNN contract, TODO confirm at use site.
CUDNN_WORKSPACE_SIZE = 12800
|
||||
|
||||
|
||||
@@ -1398,6 +1432,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
self.cp_kv_cache_interleave_size: int = (
|
||||
get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
|
||||
)
|
||||
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
@@ -2048,29 +2087,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||
|
||||
if fp8_attention:
|
||||
ql_nope_shape = decode_ql_nope.shape
|
||||
q_pe_shape = decode_q_pe.shape
|
||||
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
|
||||
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
|
||||
decode_q_shape = (
|
||||
ql_nope_shape[0],
|
||||
ql_nope_shape[1],
|
||||
ql_nope_shape[2] + q_pe_shape[2],
|
||||
decode_q = self._decode_concat_quant_fp8_op(
|
||||
decode_ql_nope, decode_q_pe, layer._q_scale
|
||||
)
|
||||
# Using empty and copy since torch.cat introduces significant overhead.
|
||||
decode_q0 = torch.empty(
|
||||
decode_q_shape,
|
||||
device=decode_ql_nope.device,
|
||||
dtype=decode_ql_nope.dtype,
|
||||
)
|
||||
decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
|
||||
decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
|
||||
|
||||
decode_q, _ = ops.scaled_fp8_quant(
|
||||
decode_q0.view(decode_q_shape[0], -1),
|
||||
layer._q_scale,
|
||||
)
|
||||
decode_q = decode_q.view(decode_q_shape)
|
||||
else:
|
||||
decode_q = (decode_ql_nope, decode_q_pe)
|
||||
if self.dcp_world_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user