[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: Tuan, Hoang-Trong
Date: 2025-07-09 15:53:55 -04:00
Committed by: GitHub
Parent: 31b96d1c64
Commit: 47043eb678
15 changed files with 1117 additions and 1142 deletions

vllm/model_executor/layers/mamba/mamba_mixer2.py

@@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
+from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
+                                                              update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader(
     tp_size: int,
     tp_rank: int,
 ) -> LoaderFunction:
-    """Create a weight loader for mamba v2. This ensures that the projections
-    are correctly sharded so that they can be split into x, B, C. It also
-    ensures that all the groups corresponding to a head shard is placed
+    """Create a weight loader for mamba v2. This ensures that the projections
+    are correctly sharded so that they can be split into x, B, C. It also
+    ensures that all the groups corresponding to a head shard is placed
     together with it.
     """
@@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
             if attn_metadata is not None:
                 assert isinstance(attn_metadata, dict)
                 attn_metadata = attn_metadata[self.prefix]
+                mamba2_metadata = attn_metadata
                 assert isinstance(attn_metadata, Mamba2AttentionMetadata)
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-                conv_state = self_kv_cache[0]
+                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
+                conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
                 state_indices_tensor = attn_metadata.state_indices_tensor
                 has_initial_states_p = attn_metadata.has_initial_states
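
The transpose above is the central layout trick of this commit: the cache block is physically stored as (..., width-1, dim), so viewing it as (..., dim, width-1) keeps 'dim' on the stride-1 axis, exactly as the new comment says. A quick illustrative sketch (hypothetical shapes, not vLLM code):

    import torch

    # Hypothetical cache block stored as (num_slots, width-1, dim),
    # C-contiguous, so 'dim' is the fastest-moving axis in memory.
    conv_cache = torch.zeros(4, 3, 128)

    # View it as (num_slots, dim, width-1), the layout the conv kernel expects.
    conv_state = conv_cache.transpose(-1, -2)

    print(conv_state.shape)            # torch.Size([4, 128, 3])
    print(conv_state.stride())         # (384, 1, 128): stride 1 along 'dim'
    print(conv_state.is_contiguous())  # False overall, but channel-contiguous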
@@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
+        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                 hidden_states_B_C,
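
For readers unfamiliar with the varlen layout: in v1 the decode tokens sit at the front of the flattened batch, so a single torch.split along the token dimension separates the two phases. A toy example with made-up sizes:

    import torch

    num_decodes, num_prefill_tokens = 2, 5
    hidden_states_B_C = torch.randn(7, 16)  # 7 tokens total, decode-first (v1)

    hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
        hidden_states_B_C, [num_decodes, num_prefill_tokens], dim=0)

    print(hidden_states_B_C_d.shape)  # torch.Size([2, 16])
    print(hidden_states_B_C_p.shape)  # torch.Size([5, 16])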
@@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
             # 2. Convolution sequence transformation
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "state_indices_tensor"
+            x = hidden_states_B_C_p.transpose(
+                0, 1)  # this is the form that causal-conv see
+            if mamba2_metadata.cu_seqlen is None:
+                mamba2_metadata = update_metadata(
+                    x, attn_metadata.query_start_loc, mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
-                hidden_states_B_C_p.transpose(0, 1),
+                x,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
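
The hunk above computes the transposed input x once and lazily populates sequence metadata the first time it is needed; the cu_seqlen-is-None guard means only the first Mamba layer in the stack pays the cost. A rough sketch of that lazy-init pattern with a hypothetical metadata container (the real update_metadata lives in mamba2_metadata.py and may compute more, e.g. kernel-launch bookkeeping):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class ToyMetadata:  # stand-in for Mamba2Metadata; fields are assumptions
        cu_seqlen: Optional[int] = None
        seq_lens: Optional[torch.Tensor] = None

    def toy_update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
                            meta: ToyMetadata) -> ToyMetadata:
        # query_start_loc holds cumulative token offsets, e.g. [0, 3, 7, 12]:
        # derive per-sequence lengths and the total token count once.
        meta.seq_lens = query_start_loc[1:] - query_start_loc[:-1]
        meta.cu_seqlen = int(query_start_loc[-1])
        return meta

    meta = ToyMetadata()
    if meta.cu_seqlen is None:  # same lazy guard as in the diff
        meta = toy_update_metadata(torch.randn(16, 12),
                                   torch.tensor([0, 3, 7, 12]), meta)
    assert meta.cu_seqlen == 12

Caching the result on the shared metadata object amortizes the work across all layers of the model within one forward pass.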
@@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
 
-            # TODO: Why is this needed?
-            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
             hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
                 hidden_states_B_C_p)
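
Dropping the .contiguous() call (and its TODO) is consistent with the new Triton kernel path tolerating strided tensors: the transpose-then-slice output is a view, and the old code presumably materialized a dense copy only to satisfy a downstream layout requirement. A small illustration of what the removed line was paying for:

    import torch

    # Mimic the output path: kernel result (dim, tokens) -> transpose -> slice.
    out = torch.randn(64, 10).transpose(0, 1)[:8]
    print(out.is_contiguous())  # False: a strided view, not a copy

    # The removed line materialized a dense copy at this point; downstream
    # ops that accept strided input make that extra memory traffic avoidable.
    out_copy = out.contiguous()
    print(out_copy.is_contiguous())  # True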
@@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
         # - heads and n_groups are TP-ed
         conv_dim = (self.intermediate_size +
                     2 * n_groups * self.ssm_state_size)
+        # contiguous along 'dim' axis
         conv_state_shape = (
-            divide(conv_dim, world_size),
             self.conv_kernel_size - 1,
+            divide(conv_dim, world_size),
         )
         # These are not TP-ed as they depend on A, dt_bias, D
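
The shape swap is the allocation-side half of the transpose seen earlier in the forward pass: storing each cache slot as (width-1, dim) puts 'dim' on the stride-1 axis of the allocated buffer. Sketch with hypothetical per-rank sizes:

    import torch

    conv_dim, kernel_size, num_slots = 128, 4, 10  # hypothetical sizes

    # New layout: (width-1, dim) per slot, allocated C-contiguous, so
    # consecutive channels are adjacent in memory.
    cache = torch.zeros(num_slots, kernel_size - 1, conv_dim)

    # Read side (matches conv_state = self_kv_cache[0].transpose(-1, -2)):
    view = cache.transpose(-1, -2)  # (num_slots, dim, width-1)
    print(view.stride())  # (384, 1, 128): stride 1 along 'dim'

Channel-contiguity typically lets a Triton program block issue coalesced loads when it walks the channel dimension, which is presumably the motivation for reordering the cache layout here.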