[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

Author: Mor Zusman
Date: 2024-10-17 00:12:43 +08:00
Committed by: GitHub
Parent: 415f76a9cb
Commit: fb60ae9b91
15 changed files with 504 additions and 432 deletions

vllm/model_executor/models/mamba.py

@@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState,
                                                    IsAttentionFree)
-from vllm.model_executor.models.mamba_cache import MambaCacheManager
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
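
The shape of the newly imported MambaCacheParams can be read off from how the rest of this diff builds and uses it: the two cache tensors plus a tensor of per-sequence cache-slot indices, and an at_layer_idx() helper that narrows the caches to one layer. A minimal sketch under those assumptions; the actual definition lives in mamba_cache.py and may differ in detail:

from dataclasses import dataclass

import torch


@dataclass
class MambaCacheParams:
    # Stacked per-layer caches, e.g. [num_layers, max_slots, ...].
    conv_state: torch.Tensor
    ssm_state: torch.Tensor
    # Cache slot assigned to each sequence in the current batch.
    state_indices_tensor: torch.Tensor

    def at_layer_idx(self, layer_idx: int) -> "MambaCacheParams":
        # Per-layer view of the caches; the slot indices are shared by all layers.
        return MambaCacheParams(self.conv_state[layer_idx],
                                self.ssm_state[layer_idx],
                                self.state_indices_tensor)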
@@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
         self.activation = config.hidden_act
 
     def forward(self, hidden_states: torch.Tensor,
-                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
-                ssm_state: torch.Tensor):
+                attn_metadata: AttentionMetadata,
+                mamba_cache_params: MambaCacheParams):
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=conv_state,
+                conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                conv_state,
+                mamba_cache_params.conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)
 
         # 3. State Space Model sequence transformation
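
These kernel calls carry the core of the continuous-batching change: rather than receiving a conv state tensor gathered for the current batch, causal_conv1d_fn and causal_conv1d_update now receive the full state pool together with cache_indices / conv_state_indices, so every sequence reads and writes its own fixed slot and requests can join or leave the batch without any state copying. A rough PyTorch illustration of the indexing the fused kernels are assumed to do internally; tensor shapes and names here are made up for the example:

import torch

# A state pool with one slot per tracked request, not one entry per batch position.
pool_conv_state = torch.zeros(8, 16, 3)      # [max_slots, dim, kernel_size - 1]
state_indices = torch.tensor([5, 2, 7])      # slot of each sequence in this step's batch

# Advanced indexing copies, so the result is written back; the CUDA kernels
# avoid the copy by addressing the pool through the indices directly.
batch_state = pool_conv_state[state_indices]                 # [batch, dim, kernel_size - 1]
batch_state = torch.roll(batch_state, shifts=-1, dims=-1)    # stand-in for the conv update
pool_conv_state[state_indices] = batch_state                 # states stay in their own slots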
@@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
                 and attn_metadata.context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             scan_outputs = selective_state_update(
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
                 self.A,
@@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-            )
+                state_batch_indices=mamba_cache_params.state_indices_tensor)
             scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
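
The single-token decode path gets the same treatment: selective_state_update receives the whole SSM state pool plus state_batch_indices and advances each sequence's recurrent state in place. A simplified sketch of that per-token recurrence; the real kernel fuses these steps and handles the exact discretization, and the names below are illustrative only:

import torch


def decode_step(ssm_pool, state_indices, x, dt, A, B, C):
    # ssm_pool: [max_slots, dim, dstate]    x, dt: [batch, dim]
    # A: [dim, dstate]                      B, C: [batch, dstate]
    h = ssm_pool[state_indices]                 # [batch, dim, dstate]
    dA = torch.exp(dt.unsqueeze(-1) * A)        # discretized state transition
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)      # discretized input projection
    h = h * dA + dB * x.unsqueeze(-1)           # one step of the recurrence
    ssm_pool[state_indices] = h                 # each sequence updates its own slot
    return torch.einsum("bdn,bn->bd", h, C)     # read the state out through C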
@@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
+        hidden_states = self.mixer(hidden_states, attn_metadata,
+                                   mamba_cache_params)
         return hidden_states, residual
@@ -275,25 +277,20 @@ class MambaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(input_ids)
         residual = None
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            current_ssm_state = ssm_state[i]
-            current_conv_state = conv_state[i]
-
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
-            )
+                mamba_cache_params=mamba_cache_params.at_layer_idx(i))
+
         hidden_states, _ = self.norm_f(hidden_states, residual)
 
         return hidden_states
@@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                 self.lm_head.weight.dtype, self.config.num_hidden_layers,
                 max_batch_size, *self._get_mamba_cache_shape())
 
-        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
-            input_ids, attn_metadata, **kwargs)
-
+        (
+            mamba_cache_tensors,
+            state_indices_tensor,
+        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
+                                                 **kwargs)
+        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
+                                              mamba_cache_tensors[1],
+                                              state_indices_tensor)
         hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_tensors[0],
-                                      mamba_cache_tensors[1])
+                                      mamba_cache_params)
 
         return hidden_states
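
Taken together, the forward pass now threads one params object from the cache manager down into the kernels, and continuous batching reduces to bookkeeping over slot indices. A toy illustration of the slot-reuse idea, with the caching and scheduling entirely faked for brevity (the real MambaCacheManager allocates and frees slots per request):

import torch

# Toy state pool: one slot per request, shared across scheduler steps.
max_slots, dim = 4, 3
ssm_pool = torch.zeros(max_slots, dim)


def run_step(slot_indices, delta):
    # Stand-in for a model step: each running sequence updates its own slot in place.
    ssm_pool[slot_indices] += delta


# Step 1: requests A (slot 0) and B (slot 1) run together.
run_step(torch.tensor([0, 1]), torch.tensor([[1.0] * dim, [2.0] * dim]))

# Step 2: A finished, its slot is reset and handed to a new request C;
# B keeps running in slot 1 with no state copying between steps.
ssm_pool[0].zero_()
run_step(torch.tensor([1, 0]), torch.tensor([[2.0] * dim, [5.0] * dim]))

print(ssm_pool[1])   # B's state accumulated over both steps: tensor([4., 4., 4.])
print(ssm_pool[0])   # C's fresh state in the recycled slot:  tensor([5., 5., 5.])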