[Model Bash] DeepSeek R1 BF16 Min Latency QKV A GEMM (0.5% E2E Speedup) (#34758)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Robert Shaw
2026-02-18 10:42:36 -05:00
committed by GitHub
parent e24663c5a9
commit 6874638bc4
7 changed files with 855 additions and 3 deletions


@@ -32,6 +32,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
+import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
@@ -711,6 +712,64 @@ class Indexer(nn.Module):
        return self.indexer_op(hidden_states, q_fp8, k, weights)
+
+
+class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            input_size,
+            output_size,
+            bias=False,
+            quant_config=quant_config,
+            disable_tp=True,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
+        )
+        # Check if the DeepSeek V3 fused A GEMM kernel can be used.
+        # This kernel supports PDL (programmatic dependent launch) and is
+        # optimized for low batch size. The shape check pins this path to
+        # DeepSeek V3/R1 in BF16: 2112 = q_lora_rank (1536) + kv_lora_rank
+        # (512) + qk_rope_head_dim (64), and 7168 is the hidden size.
+        self._use_min_latency_gemm = (
+            hasattr(self, "weight")
+            and self.weight.dtype == torch.bfloat16
+            and self.weight.shape[0] == 2112
+            and self.weight.shape[1] == 7168
+            and current_platform.is_cuda()
+            and (
+                # Hopper (SM90) or the Blackwell (SM100) family only.
+                current_platform.is_device_capability(90)
+                or current_platform.is_device_capability_family(100)
+            )
+        )
+
+    def forward(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
+        num_tokens = input_.shape[0]
+        if self._use_min_latency_gemm and (0 < num_tokens <= 16):
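+            # At <=16 tokens (decode-sized batches) a single PDL-enabled
+            # fused GEMM launch minimizes latency; larger batches fall
+            # through to the generic matmul path below.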
+            output = torch.empty(
+                num_tokens,
+                2112,
+                dtype=torch.bfloat16,
+                device=input_.device,
+            )
+            ops.dsv3_fused_a_gemm(
+                output,
+                input_,
+                self.weight.T,
+            )
+            if not self.return_bias:
+                return output
+            output_bias = self.bias if self.skip_bias_add else None
+            return output, output_bias
+        else:
+            # Fall back to the standard forward method when
+            # the fused A GEMM kernel cannot be used.
+            return super().forward(input_)
+
+
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
@@ -756,13 +815,11 @@ class DeepseekV2MLAAttention(nn.Module):
        self.max_position_embeddings = max_position_embeddings
        if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
+            self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True,
            )
        else:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
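
For context, a minimal standalone sketch of the dispatch this commit adds. It uses a plain matmul as a stand-in for ops.dsv3_fused_a_gemm (the real kernel is a CUDA min-latency GEMM with PDL); the threshold and shapes are copied from the diff, and everything else here is illustrative rather than vLLM API:

import torch

M_MAX = 16          # min-latency kernel covers decode-sized batches only
N, K = 2112, 7168   # fused q_a + kv_a output dim; DeepSeek V3/R1 hidden size

def fused_a_proj(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Sketch of the dispatch in DeepSeekV2FusedQkvAProj.forward."""
    num_tokens = x.shape[0]
    if weight.dtype == torch.bfloat16 and 0 < num_tokens <= M_MAX:
        # Real path: preallocate the output and make one fused kernel call,
        # ops.dsv3_fused_a_gemm(output, x, weight.T); matmul stands in here.
        output = torch.empty(num_tokens, N, dtype=torch.bfloat16, device=x.device)
        torch.matmul(x, weight.T, out=output)
        return output
    # Fallback: the standard linear-layer path (super().forward in the diff).
    return x @ weight.T

x = torch.randn(4, K, dtype=torch.bfloat16)   # 4 decode tokens -> fused path
w = torch.randn(N, K, dtype=torch.bfloat16)
assert fused_a_proj(x, w).shape == (4, N)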