[MoE][Perf] Wrap DSV3 QKVAProj GEMM in custom op for torch.compile (#35751)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -75,6 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerBackend,
|
||||
@@ -717,6 +718,44 @@ class Indexer(nn.Module):
|
||||
return self.indexer_op(hidden_states, q_fp8, k, weights)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_impl(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Dynamically run min-latency gemm if num_tokens <= 16.
|
||||
This must be wrapped in a custom op because our torch.compile integration
|
||||
does not support runtime dispatching on num_tokens.
|
||||
"""
|
||||
num_tokens = input_.shape[0]
|
||||
if 0 < num_tokens <= 16:
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
weight.shape[0],
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(output, input_, weight.T)
|
||||
return output
|
||||
else:
|
||||
return torch.nn.functional.linear(input_, weight)
|
||||
|
||||
|
||||
def _min_latency_fused_qkv_a_proj_fake(
|
||||
input_: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return input_.new_empty(input_.shape[0], weight.shape[0])
|
||||
|
||||
|
||||
# Register the dynamic-dispatch GEMM as torch.ops.vllm.min_latency_fused_qkv_a_proj
# so callers (and torch.compile, via the fake impl) see a single opaque op
# instead of the data-dependent num_tokens branch. No arguments are mutated
# from the caller's perspective.
direct_register_custom_op(
    op_name="min_latency_fused_qkv_a_proj",
    op_func=_min_latency_fused_qkv_a_proj_impl,
    mutates_args=[],
    fake_impl=_min_latency_fused_qkv_a_proj_fake,
)
|
||||
|
||||
|
||||
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -752,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
|
||||
self,
|
||||
input_,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
|
||||
num_tokens = input_.shape[0]
|
||||
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
|
||||
output = torch.empty(
|
||||
num_tokens,
|
||||
2112,
|
||||
dtype=torch.bfloat16,
|
||||
device=input_.device,
|
||||
)
|
||||
ops.dsv3_fused_a_gemm(
|
||||
output,
|
||||
input_,
|
||||
self.weight.T,
|
||||
)
|
||||
if self._use_min_latency_gemm:
|
||||
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
|
||||
if not self.return_bias:
|
||||
return output
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
Reference in New Issue
Block a user