[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
31b96d1c64
commit
47043eb678
@@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||
update_metadata)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
@@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader(
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
) -> LoaderFunction:
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
together with it.
|
||||
"""
|
||||
|
||||
@@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba2_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
@@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C,
|
||||
@@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(
|
||||
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
hidden_states_B_C_p.transpose(0, 1),
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
@@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
# TODO: Why is this needed?
|
||||
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p)
|
||||
|
||||
@@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
|
||||
# - heads and n_groups are TP-ed
|
||||
conv_dim = (self.intermediate_size +
|
||||
2 * n_groups * self.ssm_state_size)
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.conv_kernel_size - 1,
|
||||
divide(conv_dim, world_size),
|
||||
)
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
|
||||
Reference in New Issue
Block a user