diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index c3e1ddb7d..5dd883f22 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionBackend
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerBackend,
@@ -717,6 +718,44 @@ class Indexer(nn.Module):
         return self.indexer_op(hidden_states, q_fp8, k, weights)
 
 
+def _min_latency_fused_qkv_a_proj_impl(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Dynamically run min-latency gemm if num_tokens <= 16.
+    This must be wrapped in a custom op because our torch.compile integration
+    does not support runtime dispatching on num_tokens.
+    """
+    num_tokens = input_.shape[0]
+    if 0 < num_tokens <= 16:
+        output = torch.empty(
+            num_tokens,
+            weight.shape[0],
+            dtype=torch.bfloat16,
+            device=input_.device,
+        )
+        ops.dsv3_fused_a_gemm(output, input_, weight.T)
+        return output
+    else:
+        return torch.nn.functional.linear(input_, weight)
+
+
+def _min_latency_fused_qkv_a_proj_fake(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    return input_.new_empty(input_.shape[0], weight.shape[0])
+
+
+direct_register_custom_op(
+    op_name="min_latency_fused_qkv_a_proj",
+    op_func=_min_latency_fused_qkv_a_proj_impl,
+    mutates_args=[],
+    fake_impl=_min_latency_fused_qkv_a_proj_fake,
+)
+
+
 class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
     def __init__(
         self,
@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
         self,
         input_,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
-        num_tokens = input_.shape[0]
-        if self._use_min_latency_gemm and (0 < num_tokens <= 16):
-            output = torch.empty(
-                num_tokens,
-                2112,
-                dtype=torch.bfloat16,
-                device=input_.device,
-            )
-            ops.dsv3_fused_a_gemm(
-                output,
-                input_,
-                self.weight.T,
-            )
+        if self._use_min_latency_gemm:
+            output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
             if not self.return_bias:
                 return output
             output_bias = self.bias if self.skip_bias_add else None
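
Note on the custom-op wrapping (not part of the diff): the real implementation keeps the `num_tokens <= 16` dispatch inside the op body, while the fake implementation only declares the output shape, so `torch.compile` traces one opaque call instead of specializing on the runtime token count. The sketch below reproduces that pattern with stock PyTorch's `torch.library.custom_op` API; the `toy::` namespace, the function names, and the plain `F.linear` stand-in for `ops.dsv3_fused_a_gemm` are illustrative assumptions, not vLLM code.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the custom-op pattern used in the diff, under a
# hypothetical "toy" namespace. The small-batch branch here just calls
# F.linear; in the actual change it would invoke the dedicated
# min-latency kernel (ops.dsv3_fused_a_gemm).
@torch.library.custom_op("toy::min_latency_linear", mutates_args=())
def min_latency_linear(input_: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Runtime dispatch on num_tokens lives inside the op body, so the
    # compiled graph never branches on the batch size.
    if 0 < input_.shape[0] <= 16:
        return F.linear(input_, weight)  # stand-in for the min-latency GEMM
    return F.linear(input_, weight)


@min_latency_linear.register_fake
def _(input_: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used while tracing under torch.compile,
    # analogous to _min_latency_fused_qkv_a_proj_fake above.
    return input_.new_empty(input_.shape[0], weight.shape[0])


@torch.compile
def fused_a_proj(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The op is opaque to the compiler; both token-count regimes share one graph.
    return min_latency_linear(x, w)
```

Because the dispatch sits inside the op, the graph produced by `torch.compile` is the same regardless of how many tokens arrive at runtime, which is why the diff moves the `num_tokens` check out of `DeepSeekV2FusedQkvAProj.forward` and into `_min_latency_fused_qkv_a_proj_impl`.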