[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
@@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState,
                                                    IsAttentionFree)
-from vllm.model_executor.models.mamba_cache import MambaCacheManager
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
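The change threads a single MambaCacheParams object through the model instead of passing conv_state and ssm_state tensors by hand. Its real definition lives in vllm/model_executor/models/mamba_cache.py and is not part of this diff; a minimal sketch consistent with how it is used below (positional construction from the two cache tensors plus a slot-index tensor, and a per-layer view via at_layer_idx) would be:

```python
from dataclasses import dataclass

import torch


@dataclass
class MambaCacheParams:
    """Bundles the Mamba caches with the batch's cache-slot indices."""
    conv_state: torch.Tensor            # [num_layers, max_slots, ...]
    ssm_state: torch.Tensor             # [num_layers, max_slots, ...]
    state_indices_tensor: torch.Tensor  # [batch] -> slot in the caches

    def at_layer_idx(self, layer_idx: int) -> "MambaCacheParams":
        # Slice out one layer's caches; the slot mapping is shared.
        return MambaCacheParams(self.conv_state[layer_idx],
                                self.ssm_state[layer_idx],
                                self.state_indices_tensor)
```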
@@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
         self.activation = config.hidden_act

     def forward(self, hidden_states: torch.Tensor,
-                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
-                ssm_state: torch.Tensor):
+                attn_metadata: AttentionMetadata,
+                mamba_cache_params: MambaCacheParams):

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=conv_state,
+                conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                conv_state,
+                mamba_cache_params.conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)

         # 3. State Space Model sequence transformation
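The interesting part of this hunk is the new cache_indices / conv_state_indices argument: the conv cache is allocated once with one slot per possible request, and each forward pass hands the kernel a tensor mapping batch rows to their slots, so states never have to be repacked into a contiguous batch as requests join and leave. The real work happens inside the causal_conv1d CUDA kernels; a toy PyTorch emulation of the decode-time update, with illustrative shapes, is:

```python
import torch

# Toy emulation of indexed conv-state updates (the real logic lives in
# the causal_conv1d CUDA kernels). One cache slot per request slot, not
# per batch row; `state_indices` maps batch row -> slot.
max_slots, channels, width = 8, 4, 3
conv_state = torch.zeros(max_slots, channels, width)
state_indices = torch.tensor([5, 0, 2])      # 3 live sequences this step

def toy_conv1d_update(x: torch.Tensor) -> torch.Tensor:
    # x: [batch, channels], the newest token's activations per sequence
    window = conv_state[state_indices]       # copy of the active slots
    window = torch.roll(window, shifts=-1, dims=-1)
    window[..., -1] = x                      # append the new column
    conv_state[state_indices] = window       # scatter back to the slots
    return window.sum(dim=-1)                # stand-in for the actual conv

out = toy_conv1d_update(torch.randn(3, channels))
```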
@@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
                 and attn_metadata.context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             scan_outputs = selective_state_update(
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
                 self.A,
@@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-            )
+                state_batch_indices=mamba_cache_params.state_indices_tensor)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection
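The SSM side gets the same treatment on both paths: selective_scan_fn (prefill, variable-length batches via query_start_loc) takes cache_indices, and selective_state_update (single-token decode) takes state_batch_indices. In both cases the kernel reads and writes each sequence's recurrent state in place at its own slot. A toy emulation of the decode-side recurrence with indexed state, under the same illustrative shapes as above, is:

```python
import torch

# Toy indexed recurrent-state update (the real kernel is
# selective_state_update; this only illustrates the slot indexing).
max_slots, d_inner, d_state = 8, 4, 16
ssm_state = torch.zeros(max_slots, d_inner, d_state)
state_indices = torch.tensor([5, 0, 2])

def toy_state_update(x, dt, A, B, C):
    # x, dt: [batch, d_inner]; A: [d_inner, d_state]; B, C: [batch, d_state]
    h = ssm_state[state_indices]                 # [batch, d_inner, d_state]
    dA = torch.exp(dt.unsqueeze(-1) * A)         # discretized decay
    dBx = dt.unsqueeze(-1) * B.unsqueeze(1) * x.unsqueeze(-1)
    h = h * dA + dBx                             # one recurrence step
    ssm_state[state_indices] = h                 # write back to the slots
    return torch.einsum("bds,bs->bd", h, C)      # project state to output

out = toy_state_update(torch.randn(3, d_inner), torch.rand(3, d_inner),
                       -torch.rand(d_inner, d_state),
                       torch.randn(3, d_state), torch.randn(3, d_state))
```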
@@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
+        hidden_states = self.mixer(hidden_states, attn_metadata,
+                                   mamba_cache_params)
         return hidden_states, residual


@@ -275,25 +277,20 @@ class MambaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
     ) -> torch.Tensor:

         hidden_states = self.embeddings(input_ids)
         residual = None

         for i in range(len(self.layers)):
             layer = self.layers[i]
-            current_ssm_state = ssm_state[i]
-            current_conv_state = conv_state[i]
-
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
-            )
+                mamba_cache_params=mamba_cache_params.at_layer_idx(i))
         hidden_states, _ = self.norm_f(hidden_states, residual)

         return hidden_states
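The stacked-layer loop no longer slices conv_state[i] / ssm_state[i] by hand; assuming the MambaCacheParams sketch above, the new call is equivalent to the removed slicing while also carrying the shared slot-index tensor along:

```python
# Runs against the MambaCacheParams sketch shown earlier (illustrative).
layer_params = mamba_cache_params.at_layer_idx(i)
assert torch.equal(layer_params.conv_state, mamba_cache_params.conv_state[i])
assert torch.equal(layer_params.ssm_state, mamba_cache_params.ssm_state[i])
```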
@@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                 self.lm_head.weight.dtype, self.config.num_hidden_layers,
                 max_batch_size, *self._get_mamba_cache_shape())

-        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
-            input_ids, attn_metadata, **kwargs)
+        (
+            mamba_cache_tensors,
+            state_indices_tensor,
+        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
+                                                 **kwargs)
+
+        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
+                                              mamba_cache_tensors[1],
+                                              state_indices_tensor)

         hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_tensors[0],
-                                      mamba_cache_tensors[1])
+                                      mamba_cache_params)

         return hidden_states

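For this to work, MambaCacheManager.current_run_tensors now also returns the state_indices_tensor for the current batch. The manager's bookkeeping is not shown in this diff; a hypothetical minimal version of the request-to-slot mapping it needs (names are illustrative, not vllm's API) looks like:

```python
import torch

class ToySlotManager:
    """Hypothetical request-id -> cache-slot bookkeeping for continuous
    batching; not vllm's actual MambaCacheManager."""

    def __init__(self, max_slots: int):
        self._free = list(range(max_slots))
        self._slot_of: dict[str, int] = {}

    def state_indices_for(self, request_ids: list[str]) -> torch.Tensor:
        # Assign a free slot to new requests; reuse slots for running ones.
        indices = []
        for rid in request_ids:
            if rid not in self._slot_of:
                self._slot_of[rid] = self._free.pop()
            indices.append(self._slot_of[rid])
        return torch.tensor(indices, dtype=torch.int32)

    def release(self, request_id: str) -> None:
        # Return the slot to the pool when a request finishes.
        self._free.append(self._slot_of.pop(request_id))
```

With a mapping like this, a finished request's slot is recycled immediately, which is what lets new requests join a running batch without repacking the cache tensors.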