[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Committed by GitHub on 2025-09-23 17:26:13 +02:00
parent 2c58742dff
commit a903669e10
31 changed files with 352 additions and 2296 deletions
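At a glance: the hunks below (all from the MambaMixer2 layer) drop every `envs.VLLM_USE_V1` branch. `forward` and `forward_cuda` lose their `mamba_cache_params` and `mamba2_metadata` arguments, the layer always dispatches through the registered custom op, and metadata now comes from the forward context. A minimal, self-contained sketch of the resulting call shape; `mamba_mixer2_op` and `MambaMixer2Sketch` are hypothetical stand-ins for the real custom op and layer so the snippet runs outside vLLM:

from typing import Optional

import torch

def mamba_mixer2_op(hidden_states: torch.Tensor,
                    output: torch.Tensor,
                    layer_name: str,
                    mup_vector: Optional[torch.Tensor] = None) -> None:
    # Stand-in for the fused kernel behind torch.ops.vllm.mamba_mixer2.
    output.copy_(hidden_states)

class MambaMixer2Sketch:
    prefix = "model.layers.0.mixer"

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                mup_vector: Optional[torch.Tensor] = None) -> None:
        # Post-change shape: no VLLM_USE_V1 branch and no mamba_cache_params /
        # mamba2_metadata arguments; the op finds the layer again by its prefix.
        mamba_mixer2_op(hidden_states, output, self.prefix, mup_vector)

x = torch.randn(4, 8)
out = torch.empty_like(x)
MambaMixer2Sketch().forward(x, out)
assert torch.equal(out, x)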


@@ -9,7 +9,6 @@ if TYPE_CHECKING:
 import torch
 from torch import nn
-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
-                                                              update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
                                        self.use_rms_norm,
                                        eps=rms_norm_eps)
-        if envs.VLLM_USE_V1:
-            compilation_config = get_current_vllm_config().compilation_config
-            if prefix in compilation_config.static_forward_context:
-                raise ValueError(f"Duplicate layer name: {prefix}")
-            compilation_config.static_forward_context[prefix] = self
-            # The outer list is for v0 PP virtual engine. Though this code path
-            # only runs for v1, we have to do this to unify with the interface
-            # of Attention + v0 PP.
-            # The inner tuple is (conv_state, ssm_state)
-            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+        # The tuple is (conv_state, ssm_state)
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))
         self.model_config = model_config
         self.cache_config = cache_config
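An aside on the registration the new code keeps unconditionally: each layer claims its prefix in the compilation config's static forward context exactly once, and a duplicate prefix is a hard error. A toy, self-contained illustration of that pattern (the names below are stand-ins, not vLLM APIs):

static_forward_context: dict[str, object] = {}

def register_layer(prefix: str, layer: object) -> None:
    # Mirrors the check above: one entry per layer prefix, duplicates rejected.
    if prefix in static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_forward_context[prefix] = layer

register_layer("model.layers.0.mixer", object())
try:
    register_layer("model.layers.0.mixer", object())
except ValueError as err:
    print(err)  # Duplicate layer name: model.layers.0.mixer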
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         pass
@@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
-                             mamba2_metadata, mup_vector)
-        else:
-            torch.ops.vllm.mamba_mixer2(
-                hidden_states,
-                output,
-                self.prefix,
-                mup_vector,
-            )
+        torch.ops.vllm.mamba_mixer2(
+            hidden_states,
+            output,
+            self.prefix,
+            mup_vector,
+        )
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         forward_context = get_forward_context()
-        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                mamba2_metadata = attn_metadata
-                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
-                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
-                conv_state = self_kv_cache[0].transpose(-1, -2)
-                ssm_state = self_kv_cache[1]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-        else:
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states_p
-            prep_initial_states = mamba2_metadata.prep_initial_states
-            chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx_p
-            chunk_indices_p = mamba2_metadata.chunk_indices_p
-            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
+        if attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, Mamba2AttentionMetadata)
+            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            # conv_state = (..., dim, width-1) yet contiguous along 'dim'
+            conv_state = self_kv_cache[0].transpose(-1, -2)
+            ssm_state = self_kv_cache[1]
+            state_indices_tensor = attn_metadata.state_indices_tensor
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
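The V1-only metadata path above reduces to a per-layer dictionary lookup: `forward_context.attn_metadata` is either None (profile run) or a dict mapping each layer prefix to its own Mamba2AttentionMetadata. A toy illustration with a hypothetical stand-in type:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyMamba2Metadata:  # hypothetical stand-in for Mamba2AttentionMetadata
    chunk_size: int
    prep_initial_states: bool

def fetch_layer_metadata(attn_metadata: Optional[dict], prefix: str):
    # Same pattern as forward_cuda above: None means profile run; otherwise the
    # dict is keyed by layer prefix and holds that layer's metadata.
    if attn_metadata is None:
        return None
    return attn_metadata[prefix]

metadata = {"model.layers.1.mixer": ToyMamba2Metadata(256, True)}
assert fetch_layer_metadata(None, "model.layers.1.mixer") is None
assert fetch_layer_metadata(metadata, "model.layers.1.mixer").chunk_size == 256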
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
             dim=-1,
         )
-        if envs.VLLM_USE_V1 and attn_metadata is None:
-            # V1 profile run
+        if attn_metadata is None:
+            # profile run
             hidden_states_B_C = (hidden_states_B_C.transpose(
                 0, 1).clone().transpose(0, 1)).contiguous()
         hidden_states, _B, _C = split_hidden_states_B_C_fn(
@@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
         has_decode = num_decodes > 0
         num_actual_tokens = num_prefill_tokens + num_decodes
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        if envs.VLLM_USE_V1:
-            hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
-                hidden_states_B_C[:num_actual_tokens],
-                [num_decodes, num_prefill_tokens],
-                dim=0,
-            )
-            dt_d, dt_p = torch.split(
-                dt[:num_actual_tokens],
-                [num_decodes, num_prefill_tokens],
-                dim=0,
-            )
-            # Split along batch dimension
-            state_indices_tensor_d, state_indices_tensor_p = torch.split(
-                state_indices_tensor[:num_actual_tokens],
-                [num_decodes, num_prefills],
-                dim=0,
-            )
-            query_start_loc_p = (
-                attn_metadata.query_start_loc[-num_prefills - 1:] -
-                num_decodes if has_prefill else None)
-        else:
-            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
-                hidden_states_B_C,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            dt_p, dt_d = torch.split(
-                dt,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            # Split along batch dimension
-            state_indices_tensor_p, state_indices_tensor_d = torch.split(
-                state_indices_tensor,
-                [num_prefills, num_decodes],
-                dim=0,
-            )
-            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
-                                                               1]
-                                 if has_prefill else None)
+        hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
+            hidden_states_B_C[:num_actual_tokens],
+            [num_decodes, num_prefill_tokens],
+            dim=0,
+        )
+        dt_d, dt_p = torch.split(
+            dt[:num_actual_tokens],
+            [num_decodes, num_prefill_tokens],
+            dim=0,
+        )
+        # Split along batch dimension
+        state_indices_tensor_d, state_indices_tensor_p = torch.split(
+            state_indices_tensor[:num_actual_tokens],
+            [num_decodes, num_prefills],
+            dim=0,
+        )
+        query_start_loc_p = (
+            attn_metadata.query_start_loc[-num_prefills - 1:] -
+            num_decodes if has_prefill else None)
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        if envs.VLLM_USE_V1:
-            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
-                preallocated_ssm_out,
-                [num_decodes, num_prefill_tokens],
-                dim=0,
-            )
-        else:
-            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-                preallocated_ssm_out,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
+        preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
+            preallocated_ssm_out,
+            [num_decodes, num_prefill_tokens],
+            dim=0,
+        )
         # Process prefill requests
         if has_prefill:
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
             # pointed to by "state_indices_tensor"
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
-            if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(x, query_start_loc_p,
-                                                  mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                metadata=mamba2_metadata,
+                metadata=attn_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
@@ -806,8 +748,6 @@ def mamba_mixer2(
     self = forward_context.no_compile_layers[layer_name]
     self.forward_cuda(hidden_states=hidden_states,
                       output=output,
-                      mamba_cache_params=None,
-                      mamba2_metadata=None,
                       mup_vector=mup_vector)
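One detail worth remembering now that only the V1 ordering survives: decode tokens sit before prefill tokens in the flattened batch, so every split above uses [num_decodes, num_prefill_tokens] along dim 0 (the deleted V0 branch used the reverse order). A small runnable illustration of that split:

import torch

num_decodes, num_prefill_tokens = 3, 7
hidden_states_B_C = torch.randn(num_decodes + num_prefill_tokens, 16)

# Decode-first layout, as in the V1-only code path kept above.
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
    hidden_states_B_C,
    [num_decodes, num_prefill_tokens],
    dim=0,
)
assert hidden_states_B_C_d.shape[0] == num_decodes
assert hidden_states_B_C_p.shape[0] == num_prefill_tokens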