[Model Bash] DeepSeek R1 BF16 Min Latency QKV A GEMM (0.5% E2E Speedup) (#34758)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -32,6 +32,7 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
 
+import vllm._custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
@@ -711,6 +712,64 @@ class Indexer(nn.Module):
         return self.indexer_op(hidden_states, q_fp8, k, weights)
 
 
+class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            input_size,
+            output_sizes,
+            bias=False,
+            quant_config=quant_config,
+            disable_tp=True,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
+        )
+
+        # Check if the DeepSeek V3 fused A GEMM kernel can be used.
+        # This kernel supports PDL and is optimized for low batch sizes.
+        self._use_min_latency_gemm = (
+            hasattr(self, "weight")
+            and self.weight.dtype == torch.bfloat16
+            and self.weight.shape[0] == 2112
+            and self.weight.shape[1] == 7168
+            and current_platform.is_cuda()
+            and (
+                current_platform.is_device_capability(90)
+                or current_platform.is_device_capability_family(100)
+            )
+        )
+
+    def forward(
+        self,
+        input_,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
+        num_tokens = input_.shape[0]
+        if self._use_min_latency_gemm and (0 < num_tokens <= 16):
+            output = torch.empty(
+                num_tokens,
+                2112,
+                dtype=torch.bfloat16,
+                device=input_.device,
+            )
+            ops.dsv3_fused_a_gemm(
+                output,
+                input_,
+                self.weight.T,
+            )
+            if not self.return_bias:
+                return output
+            output_bias = self.bias if self.skip_bias_add else None
+            return output, output_bias
+        else:
+            # Fall back to the standard forward method when
+            # the fused A GEMM kernel cannot be used.
+            return super().forward(input_)
+
+
 class DeepseekV2MLAAttention(nn.Module):
     """
     Main reference: DeepseekV2 paper, and FlashInfer Implementation
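The pattern above is a shape- and dtype-gated kernel dispatch: the static preconditions (weight dtype and shape, CUDA platform, SM90 or SM100-family capability) are evaluated once in __init__, so the hot path in forward() is reduced to one cached boolean plus a cheap token-count check. Below is a minimal CPU-runnable sketch of the same pattern; the GatedGemm class and its names are hypothetical (not part of vLLM), and torch.matmul stands in for the CUDA-only ops.dsv3_fused_a_gemm kernel, which writes into a preallocated output buffer.

# Minimal sketch of shape/dtype-gated kernel dispatch (hypothetical names).
import torch
from torch import nn

class GatedGemm(nn.Module):
    """Dispatch to a specialized kernel only when its preconditions hold."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features, dtype=torch.bfloat16)
        )
        # Static preconditions are evaluated once at init time, mirroring
        # how DeepSeekV2FusedQkvAProj checks dtype, shape, and platform.
        self._use_fast_path = (
            self.weight.dtype == torch.bfloat16
            and self.weight.shape == (2112, 7168)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the cheap per-call precondition (tiny batch) runs per forward.
        if self._use_fast_path and 0 < x.shape[0] <= 16:
            # Stand-in for ops.dsv3_fused_a_gemm(output, x, self.weight.T):
            # the real kernel also fills a preallocated output tensor.
            output = torch.empty(
                x.shape[0], self.weight.shape[0], dtype=torch.bfloat16
            )
            torch.matmul(x, self.weight.T, out=output)
            return output
        # Fallback: the general path handles every other case.
        return x @ self.weight.T

if __name__ == "__main__":
    layer = GatedGemm(7168, 2112)
    x = torch.randn(4, 7168, dtype=torch.bfloat16)  # 4 tokens -> fast path
    print(layer(x).shape)  # torch.Size([4, 2112])

Doing the expensive checks at construction time is what makes the gate essentially free for decode-sized batches, which is where a min-latency kernel pays off.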
@@ -756,13 +815,11 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
+            self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True,
             )
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
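Note that the call site drops bias=False and disable_tp=True because the subclass hard-codes both in its super().__init__ call. For context on the hard-coded 2112 x 7168 gate in _use_min_latency_gemm: with the published DeepSeek-V3 config (hidden_size 7168, q_lora_rank 1536, kv_lora_rank 512, qk_rope_head_dim 64), the fused weight has exactly that shape, as this quick check shows.

# Published DeepSeek-V3 config values.
hidden_size = 7168
q_lora_rank = 1536
kv_lora_rank = 512
qk_rope_head_dim = 64

# fused_qkv_a_proj concatenates the [q_lora_rank, kv_lora_rank +
# qk_rope_head_dim] output partitions, so its weight is (2112, 7168) --
# exactly the shape the min-latency kernel is specialized for.
assert q_lora_rank + (kv_lora_rank + qk_rope_head_dim) == 2112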