[Perf] Enable dual stream execution of input projection for Qwen3 (#36795)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -180,12 +180,16 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
|
||||
# ============================================================
|
||||
# Part 1: Input Projection
|
||||
# ============================================================
|
||||
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
|
||||
hidden_states,
|
||||
self.in_proj_qkvz.weight.shape[0],
|
||||
self.in_proj_ba.weight.shape[0],
|
||||
self.prefix,
|
||||
)
|
||||
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
|
||||
z_size = self.value_dim // self.tp_size
|
||||
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
|
||||
z = z.reshape(z.size(0), -1, self.head_v_dim)
|
||||
ba, _ = self.in_proj_ba(hidden_states)
|
||||
b, a = ba.chunk(2, dim=-1)
|
||||
|
||||
b = b.contiguous()
|
||||
|
||||
@@ -82,7 +82,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
|
||||
from vllm.utils.torch_utils import (
|
||||
aux_stream,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
|
||||
@@ -419,6 +423,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
self.layer_norm_epsilon = config.rms_norm_eps
|
||||
self.prefix = prefix
|
||||
self.aux_stream = aux_stream()
|
||||
self.events = (
|
||||
[torch.cuda.Event(), torch.cuda.Event()]
|
||||
if current_platform.is_cuda()
|
||||
else [None, None]
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model_config = model_config
|
||||
@@ -647,8 +657,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
# ============================================================
|
||||
# Part 1: Input Projection
|
||||
# ============================================================
|
||||
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
|
||||
projected_states_ba, _ = self.in_proj_ba(hidden_states)
|
||||
projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj(
|
||||
hidden_states,
|
||||
self.in_proj_qkvz.weight.shape[0],
|
||||
self.in_proj_ba.weight.shape[0],
|
||||
self.prefix,
|
||||
)
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||
projected_states_qkvz, projected_states_ba
|
||||
)
|
||||
@@ -783,6 +797,18 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
def _forward_in_proj(
|
||||
self, hidden_states: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
|
||||
lambda: self.in_proj_qkvz(hidden_states)[0],
|
||||
lambda: self.in_proj_ba(hidden_states)[0],
|
||||
self.events[0],
|
||||
self.events[1],
|
||||
self.aux_stream,
|
||||
)
|
||||
return projected_states_qkvz, projected_states_ba
|
||||
|
||||
def _forward_core(
|
||||
self,
|
||||
mixed_qkv: torch.Tensor,
|
||||
@@ -1670,6 +1696,32 @@ class Qwen3NextForCausalLM(
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
|
||||
def gdn_in_proj(
|
||||
hidden_states: torch.Tensor,
|
||||
qkvz_output_size: int,
|
||||
ba_output_size: int,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Custom op for the input projection.
|
||||
"""
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
return self._forward_in_proj(hidden_states)
|
||||
|
||||
|
||||
def gdn_in_proj_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
qkvz_output_size: int,
|
||||
ba_output_size: int,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Fake implementation for torch.compile."""
|
||||
return hidden_states.new_empty(
|
||||
hidden_states.shape[0], qkvz_output_size
|
||||
), hidden_states.new_empty(hidden_states.shape[0], ba_output_size)
|
||||
|
||||
|
||||
def gdn_attention_core(
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
@@ -1703,6 +1755,12 @@ def gdn_attention_core_fake(
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="gdn_in_proj",
|
||||
op_func=gdn_in_proj,
|
||||
fake_impl=gdn_in_proj_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="gdn_attention_core",
|
||||
op_func=gdn_attention_core,
|
||||
|
||||
48
vllm/utils/multi_stream_utils.py
Normal file
48
vllm/utils/multi_stream_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def maybe_execute_in_parallel(
|
||||
fn0: Callable[[], Any],
|
||||
fn1: Callable[[], Any],
|
||||
event0: torch.cuda.Event,
|
||||
event1: torch.cuda.Event,
|
||||
aux_stream: torch.cuda.Stream | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Run two functions potentially in parallel on separate CUDA streams.
|
||||
|
||||
When aux_stream is provided, fn0 runs on the current (default) stream and
|
||||
fn1 runs on aux_stream, synchronized via CUDA events. When aux_stream is
|
||||
None, both functions execute sequentially on the current stream.
|
||||
|
||||
This design follows TensorRT-LLM's maybe_execute_in_parallel pattern
|
||||
(tensorrt_llm/_torch/modules/multi_stream_utils.py).
|
||||
|
||||
Args:
|
||||
fn0: Callable for the default stream.
|
||||
fn1: Callable for the auxiliary stream.
|
||||
event0: CUDA event recorded before fn0 so aux_stream can wait.
|
||||
event1: CUDA event recorded after fn1 so default stream can wait.
|
||||
aux_stream: The second CUDA stream for fn1.
|
||||
Multi-stream is disabled when aux_stream is None.
|
||||
|
||||
Returns:
|
||||
Tuple of (fn0_result, fn1_result).
|
||||
"""
|
||||
if aux_stream is not None:
|
||||
event0.record()
|
||||
result0 = fn0()
|
||||
with torch.cuda.stream(aux_stream):
|
||||
event0.wait()
|
||||
result1 = fn1()
|
||||
event1.record()
|
||||
event1.wait()
|
||||
else:
|
||||
result0 = fn0()
|
||||
result1 = fn1()
|
||||
return (result0, result1)
|
||||
Reference in New Issue
Block a user