[MoE][Perf] Wrap DSV3 QKVAProj GEMM in custom op for torch.compile (#35751)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Robert Shaw authored 2026-03-02 18:03:49 -05:00 · committed by GitHub
parent c42dc402c1
commit 9319044ee9

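For context on the pattern this commit applies: torch.compile traces Python control flow, so an eager-mode branch on `num_tokens` would either bake in a shape guard or fail to express the runtime kernel choice. Hiding the branch inside a custom op makes it a single opaque call in the compiled graph. Below is a minimal sketch of that idea using PyTorch's generic `torch.library.custom_op` API (PyTorch 2.4+) rather than vLLM's `direct_register_custom_op` helper; the op name `demo::min_latency_linear` and the plain-matmul stand-in for the `dsv3_fused_a_gemm` kernel are illustrative, not part of this commit.

```python
import torch

# Illustrative sketch (not vLLM code): hide a shape-dependent kernel choice
# inside a custom op so torch.compile sees one opaque call.
@torch.library.custom_op("demo::min_latency_linear", mutates_args=())
def min_latency_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Runtime dispatch on the token count; torch.compile never traces into
    # this body, so the branch does not specialize the graph per shape.
    if 0 < x.shape[0] <= 16:
        return x @ w.t()  # stand-in for a min-latency GEMM kernel
    return torch.nn.functional.linear(x, w)

@min_latency_linear.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape/dtype/device contract used during fake-tensor tracing.
    return x.new_empty(x.shape[0], w.shape[0])
```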

@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionBackend
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerBackend,
@@ -717,6 +718,44 @@ class Indexer(nn.Module):
         return self.indexer_op(hidden_states, q_fp8, k, weights)
+
+
+def _min_latency_fused_qkv_a_proj_impl(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Dynamically run min-latency gemm if num_tokens <= 16.
+    This must be wrapped in a custom op because our torch.compile integration
+    does not support runtime dispatching on num_tokens.
+    """
+    num_tokens = input_.shape[0]
+    if 0 < num_tokens <= 16:
+        output = torch.empty(
+            num_tokens,
+            weight.shape[0],
+            dtype=torch.bfloat16,
+            device=input_.device,
+        )
+        ops.dsv3_fused_a_gemm(output, input_, weight.T)
+        return output
+    else:
+        return torch.nn.functional.linear(input_, weight)
+
+
+def _min_latency_fused_qkv_a_proj_fake(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    return input_.new_empty(input_.shape[0], weight.shape[0])
+
+
+direct_register_custom_op(
+    op_name="min_latency_fused_qkv_a_proj",
+    op_func=_min_latency_fused_qkv_a_proj_impl,
+    mutates_args=[],
+    fake_impl=_min_latency_fused_qkv_a_proj_fake,
+)
+
+
 class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
     def __init__(
         self,
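A note on the `fake_impl` registered above: during torch.compile's fake-tensor tracing it runs in place of the real kernel, so it only needs to reproduce the output's shape, dtype, and device, never the values. `input_.new_empty(...)` inherits all three from the input, which matches the eager result for the bfloat16 activations this path targets.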
@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
         self,
         input_,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
-        num_tokens = input_.shape[0]
-        if self._use_min_latency_gemm and (0 < num_tokens <= 16):
-            output = torch.empty(
-                num_tokens,
-                2112,
-                dtype=torch.bfloat16,
-                device=input_.device,
-            )
-            ops.dsv3_fused_a_gemm(
-                output,
-                input_,
-                self.weight.T,
-            )
+        if self._use_min_latency_gemm:
+            output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
             if not self.return_bias:
                 return output
             output_bias = self.bias if self.skip_bias_add else None
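As a usage illustration (again hypothetical, building on the `min_latency_linear` sketch earlier, not part of this commit), a compiled wrapper can be fed batches on both sides of the 16-token threshold; the dispatch stays inside the op, so the changing branch does not fragment the graph. The 2112/7168 widths mirror the qkv_a_proj output width visible in the removed code and the DeepSeek-V3 hidden size.

```python
import torch

# Reuses min_latency_linear from the sketch above; w's shape is illustrative.
w = torch.randn(2112, 7168)

@torch.compile(dynamic=True)
def fwd(x: torch.Tensor) -> torch.Tensor:
    return min_latency_linear(x, w)

for n in (4, 32):  # below and above the min-latency threshold
    out = fwd(torch.randn(n, 7168))
    assert out.shape == (n, 2112)
```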