Revert "[PERF] Decouple projections from GDN custom op" (#28080)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2025-11-05 03:58:23 +04:00
committed by GitHub
parent 2d977a7a9e
commit d4e547bb7e
3 changed files with 53 additions and 204 deletions

View File

@@ -30,14 +30,12 @@ from vllm.distributed import (
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
RMSNormGated,
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -438,66 +436,17 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
hidden_states: torch.Tensor,
output: torch.Tensor,
):
"""
Forward pass with three parts:
1. Input projection
2. Core attention (custom op)
3. Output projection
"""
num_tokens = hidden_states.size(0)
# ============================================================
# Part 1: Input Projection
# ============================================================
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = map(
lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
# ============================================================
# Part 2: Core Attention (Custom Op)
# ============================================================
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.gdn_attention_core(
mixed_qkv,
b,
a,
core_attn_out,
return torch.ops.vllm.gdn_attention(
hidden_states,
output,
self.prefix,
)
# ============================================================
# Part 3: Output Projection
# ============================================================
z_shape_og = z.shape
# Reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)
def _forward_core(
def _forward(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
@@ -522,11 +471,18 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
# 1. Set up dimensions for reshapes later
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = map(
lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
# 1. Convolution sequence transformation
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
@@ -542,7 +498,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# 1.1: Process the multi-query part
# 2.1: process the mutli-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
@@ -559,7 +515,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
validate_data=False,
)
# 1.2: Process the remaining part
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
# - "cache_indices" updates the conv_state cache in positions
@@ -617,9 +573,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
g_non_spec = g
beta_non_spec = beta
# 2. Recurrent attention
# 3. Recurrent attention
# 2.1: Process the multi-query part
# 3.1: process the mutlti-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
@@ -637,7 +593,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else:
core_attn_out_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
# 3.2: process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
@@ -680,20 +636,30 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# 3. Merge core attention output
# Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
merged_out = torch.empty(
core_attn_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
elif spec_sequence_masks is not None:
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
core_attn_out = core_attn_out_spec
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
core_attn_out = core_attn_out_non_spec
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class Qwen3NextAttention(nn.Module):
@@ -1304,44 +1270,29 @@ class Qwen3NextForCausalLM(
return self.model.get_expert_mapping()
def gdn_attention_core(
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
def gdn_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
"""
Custom op for the core attention computation.
Only handles the convolution + recurrent attention part.
Input/output projections are handled outside this op.
"""
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward_core(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
)
self._forward(hidden_states=hidden_states, output=output)
def gdn_attention_core_fake(
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
def gdn_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
"""Fake implementation for torch.compile."""
return
direct_register_custom_op(
op_name="gdn_attention_core",
op_func=gdn_attention_core,
mutates_args=["core_attn_out"],
fake_impl=gdn_attention_core_fake,
op_name="gdn_attention",
op_func=gdn_attention,
mutates_args=["output"],
fake_impl=gdn_attention_fake,
)