[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -10,8 +10,6 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
has_weight=rms_norm_has_weight,
|
||||
) if use_rms_norm else None
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
if not envs.VLLM_USE_V1:
|
||||
CustomOp.forward(self, hidden_states, output, mamba_cache_params)
|
||||
else:
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_native(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
output: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
"""
|
||||
Run the Mamba-1 SSM pipeline.
|
||||
|
||||
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba1_metadata = attn_metadata
|
||||
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc = mamba1_metadata.query_start_loc
|
||||
state_indices_tensor = mamba1_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states = mamba1_metadata.has_initial_states
|
||||
num_padded_decodes = mamba1_metadata.num_padded_decodes
|
||||
else:
|
||||
assert isinstance(attn_metadata, AttentionMetadata)
|
||||
assert mamba_cache_params is not None
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
context_lens_tensor = attn_metadata.context_lens_tensor
|
||||
has_initial_states = None
|
||||
if context_lens_tensor is not None:
|
||||
has_initial_states = context_lens_tensor > 0
|
||||
num_padded_decodes = attn_metadata.num_decode_tokens
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba1_metadata = attn_metadata
|
||||
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc = mamba1_metadata.query_start_loc
|
||||
state_indices_tensor = mamba1_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states = mamba1_metadata.has_initial_states
|
||||
num_padded_decodes = mamba1_metadata.num_padded_decodes
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
hidden_states_BC = hidden_states_BC.contiguous()
|
||||
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
|
||||
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
out=scan_outputs_d)
|
||||
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
else:
|
||||
ssm_outputs.append(scan_outputs_d)
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
|
||||
scan_outputs_combined = ssm_outputs[0] if len(
|
||||
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
|
||||
@@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
|
||||
num_decodes: int,
|
||||
num_padded_decodes: int,
|
||||
) -> PrefillDecodeSplit:
|
||||
|
||||
num_actual_tokens = num_prefill_tokens + num_padded_decodes
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
dim=0)
|
||||
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||
num_padded_decodes if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
else:
|
||||
# In v0, prefill tokens come first, then decode tokens.
|
||||
hidden_states_BC_p, hidden_states_BC_d = torch.split(
|
||||
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
|
||||
gate_p, gate_d = torch.split(gate,
|
||||
[num_prefill_tokens, num_decode_tokens],
|
||||
dim=-1)
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor, [num_prefills, num_decodes], dim=0)
|
||||
query_start_loc_p = (query_start_loc[:num_prefills +
|
||||
1] if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[:num_prefills] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
dim=0)
|
||||
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||
num_padded_decodes if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
|
||||
return PrefillDecodeSplit(
|
||||
hidden_states_BC_p=hidden_states_BC_p,
|
||||
@@ -495,9 +448,7 @@ def mamba_mixer(
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mamba_cache_params=None)
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer_fake(
|
||||
|
||||
Reference in New Issue
Block a user