[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# coding=utf-8
|
||||
"""Inference-only Jamba model."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheManager
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaCacheParams:
|
||||
is_prompt: bool = False
|
||||
conv_state: torch.Tensor = torch.Tensor()
|
||||
ssm_state: torch.Tensor = torch.Tensor()
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
class JambaMambaMixer(nn.Module):
|
||||
"""
|
||||
@@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self, config: JambaConfig, layer_idx):
|
||||
def __init__(self, config: JambaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.ssm_state_size = config.mamba_d_state
|
||||
self.conv_kernel_size = config.mamba_d_conv
|
||||
@@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor):
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams):
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
hidden_states = causal_conv1d_update(
|
||||
hidden_states.transpose(0, 1),
|
||||
conv_state,
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
||||
hidden_states = hidden_states.transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
@@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
scan_outputs = selective_scan_fn(
|
||||
hidden_states,
|
||||
ssm_state,
|
||||
mamba_cache_params.ssm_state,
|
||||
discrete_time_step,
|
||||
self.A,
|
||||
B.transpose(-2, -1),
|
||||
@@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
|
||||
gate,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = selective_state_update(
|
||||
ssm_state,
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
self.A,
|
||||
@@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
|
||||
gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
@@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.config = config
|
||||
self.mamba = JambaMambaMixer(config, layer_idx)
|
||||
self.mamba = JambaMambaMixer(config)
|
||||
|
||||
num_experts = config.layers_num_experts[layer_idx]
|
||||
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
|
||||
@@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
|
||||
ssm_state)
|
||||
hidden_states = self.mamba(hidden_states, attn_metadata,
|
||||
mamba_cache_params)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
hidden_states, residual)
|
||||
@@ -476,17 +469,14 @@ class JambaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
kv_cache = None
|
||||
current_ssm_state = None
|
||||
current_conv_state = None
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer, JambaAttentionDecoderLayer):
|
||||
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
|
||||
self.config.attn_layer_period]
|
||||
@@ -494,8 +484,8 @@ class JambaModel(nn.Module):
|
||||
current_state_layer = i - (1 +
|
||||
(i - self.config.attn_layer_offset)
|
||||
// self.config.attn_layer_period)
|
||||
current_ssm_state = ssm_state[current_state_layer]
|
||||
current_conv_state = conv_state[current_state_layer]
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
@@ -503,9 +493,7 @@ class JambaModel(nn.Module):
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
conv_state=current_conv_state,
|
||||
ssm_state=current_ssm_state,
|
||||
)
|
||||
mamba_cache_params=layer_mamba_cache_params)
|
||||
hidden_states, _ = self.final_layernorm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
@@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
|
||||
*self._get_mamba_cache_shape())
|
||||
|
||||
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
|
||||
input_ids, attn_metadata, **kwargs)
|
||||
|
||||
(
|
||||
mamba_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
|
||||
**kwargs)
|
||||
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1],
|
||||
state_indices_tensor)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1])
|
||||
attn_metadata, mamba_cache_params)
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user