[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

Author: Mor Zusman
Date: 2024-10-17 00:12:43 +08:00
Committed by: GitHub
Parent: 415f76a9cb
Commit: fb60ae9b91
15 changed files with 504 additions and 432 deletions

vllm/model_executor/models/jamba.py

@@ -1,6 +1,5 @@
 # coding=utf-8
 """Inference-only Jamba model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple

 import torch
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheManager
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
@@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
 KVCache = Tuple[torch.Tensor, torch.Tensor]


-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
     """
@@ -60,10 +53,9 @@
     **selective** state spaces)
     """

-    def __init__(self, config: JambaConfig, layer_idx):
+    def __init__(self, config: JambaConfig):
         super().__init__()
         self.config = config
-        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.ssm_state_size = config.mamba_d_state
         self.conv_kernel_size = config.mamba_d_conv
@@ -129,8 +121,8 @@
                                     eps=config.rms_norm_eps)

     def forward(self, hidden_states: torch.Tensor,
-                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
-                ssm_state: torch.Tensor):
+                attn_metadata: AttentionMetadata,
+                mamba_cache_params: MambaCacheParams):

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -153,17 +145,18 @@
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=conv_state,
+                conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                conv_state,
+                mamba_cache_params.conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)

         # 3. State Space Model sequence transformation
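
The new cache_indices and conv_state_indices arguments are what make continuous batching cheap here: rather than reordering or copying the convolution cache so that slot order matches the current batch order, each kernel looks up the slot owned by a sequence through state_indices_tensor. A pure-PyTorch illustration of that indexing, with assumed shapes and reference semantics only (the CUDA kernels read and update the slots in place rather than materializing a gathered copy):

    import torch

    def lookup_conv_slots(conv_state: torch.Tensor,
                          state_indices: torch.Tensor) -> torch.Tensor:
        # conv_state:    [max_batch_size, dim, kernel_size - 1] cache slots
        # state_indices: [batch_size] slot id of each running sequence
        return conv_state[state_indices]
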
@@ -188,7 +181,7 @@
                 and attn_metadata.context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -197,11 +190,12 @@
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             scan_outputs = selective_state_update(
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
                 self.A,
@@ -211,7 +205,7 @@
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-            )
+                state_batch_indices=mamba_cache_params.state_indices_tensor)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection
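
selective_state_update gets the same treatment for the single-token decode path: state_batch_indices tells the kernel which SSM-state slot each sequence owns, so slots can be freed and reused as requests finish and join without shuffling states between scheduler steps. Reference semantics with assumed shapes (not the kernel itself):

    import torch

    def write_ssm_state(ssm_state: torch.Tensor,
                        new_state: torch.Tensor,
                        state_batch_indices: torch.Tensor) -> None:
        # ssm_state:           [max_batch_size, ...] cache slots
        # new_state:           [batch_size, ...] updated per-sequence states
        # state_batch_indices: [batch_size] slot id per sequence
        ssm_state[state_batch_indices] = new_state
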
@@ -292,7 +286,7 @@
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
-        self.mamba = JambaMambaMixer(config, layer_idx)
+        self.mamba = JambaMambaMixer(config)

         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -307,8 +301,7 @@
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -318,8 +311,8 @@
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
+        hidden_states = self.mamba(hidden_states, attn_metadata,
+                                   mamba_cache_params)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -476,17 +469,14 @@
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None

         for i in range(len(self.layers)):
             layer = self.layers[i]
             kv_cache = None
-            current_ssm_state = None
-            current_conv_state = None
+            layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
                 kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                      self.config.attn_layer_period]
@@ -494,8 +484,8 @@
                 current_state_layer = i - (1 +
                                            (i - self.config.attn_layer_offset)
                                            // self.config.attn_layer_period)
-                current_ssm_state = ssm_state[current_state_layer]
-                current_conv_state = conv_state[current_state_layer]
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
+                    current_state_layer)

             hidden_states, residual = layer(
                 positions=positions,
@@ -503,9 +493,7 @@
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
-            )
+                mamba_cache_params=layer_mamba_cache_params)
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
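
The current_state_layer arithmetic above maps a global decoder-layer index to a Mamba-only index by subtracting the number of attention layers that precede it. For example, with the illustrative values attn_layer_offset=4 and attn_layer_period=8, layer i=13 is preceded by attention layers at indices 4 and 12, so its cache entry sits at Mamba index 13 - 2 = 11:

    def mamba_state_layer(i: int, attn_layer_offset: int,
                          attn_layer_period: int) -> int:
        # 1 + (i - attn_layer_offset) // attn_layer_period attention layers
        # appear at or before global index i.
        return i - (1 + (i - attn_layer_offset) // attn_layer_period)

    assert mamba_state_layer(13, attn_layer_offset=4, attn_layer_period=8) == 11
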
@@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
-        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
-            input_ids, attn_metadata, **kwargs)
+        (
+            mamba_cache_tensors,
+            state_indices_tensor,
+        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
+                                                 **kwargs)
+
+        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
+                                              mamba_cache_tensors[1],
+                                              state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_tensors[0],
-                                   mamba_cache_tensors[1])
+                                   attn_metadata, mamba_cache_params)
         return hidden_states

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
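
Manager-side, the slot indices assume that each request keeps the same cache slot for its whole lifetime and that only the list of slots touched in a step changes between iterations. A conceptual sketch of that bookkeeping (a hypothetical SlotAllocator for illustration, not MambaCacheManager's actual implementation):

    import torch

    class SlotAllocator:
        # Hypothetical illustration of lifetime-stable cache slots under
        # continuous batching; MambaCacheManager's real logic may differ.

        def __init__(self, max_batch_size: int):
            self.free = list(range(max_batch_size))
            self.owned = {}  # request id -> slot

        def allocate(self, request_id: str) -> int:
            if request_id not in self.owned:
                self.owned[request_id] = self.free.pop()
            return self.owned[request_id]

        def release(self, request_id: str) -> None:
            self.free.append(self.owned.pop(request_id))

        def state_indices_tensor(self, scheduled: list) -> torch.Tensor:
            # Slot per scheduled sequence, in batch order, as consumed by the
            # kernels above via MambaCacheParams.state_indices_tensor.
            return torch.tensor([self.allocate(r) for r in scheduled],
                                dtype=torch.long)
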