[Model] Support Mamba (#6484)

Tyler Michael Smith
2024-10-11 11:40:06 -04:00
committed by GitHub
parent df3dcdf49d
commit 7342a7d7f8
29 changed files with 1603 additions and 343 deletions

@@ -1,18 +1,16 @@
# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -29,7 +27,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@@ -99,16 +99,6 @@ class JambaMambaMixer(nn.Module):
bias=True,
skip_bias_add=True)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
@@ -118,8 +108,10 @@ class JambaMambaMixer(nn.Module):
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": weight_loader})
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear(
self.intermediate_size,
@@ -571,10 +563,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
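
With the cache tuple and the request-id/seq-id mapping dict gone, all per-request bookkeeping moves into MambaCacheManager, and the hunk below rewires forward() to drive it. Because that hunk interleaves the deleted bookkeeping with the new code, here is the new call pattern consolidated as a sketch (reconstructed from the added lines below, not verbatim file contents):

    # Sketch of the manager-driven flow, using only names that appear in
    # this diff (_get_graph_batch_size, _BATCH_SIZES_TO_CAPTURE,
    # MambaCacheManager are assumed to be imported as in jamba.py).
    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                intermediate_tensors=None, **kwargs):
        if self.mamba_cache is None:
            # Lazily size the cache for the largest CUDA-graph batch.
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            num_mamba_layers = sum(
                layer_type == "mamba"
                for layer_type in self.config.layers_block_type)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                *self._get_mamba_cache_shape())

        # The manager hands back the conv/temporal state tensors for the
        # sequences in this step; index assignment, cache copies for n>1
        # sampling, and freeing finished requests now happen inside it.
        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
            input_ids, attn_metadata, **kwargs)

        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          mamba_cache_tensors[0], mamba_cache_tensors[1])

The CUDA-graph hooks follow the same pattern further down: copy_inputs_before_cuda_graphs and get_seqlen_agnostic_capture_inputs now simply delegate to the manager.
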
@@ -586,203 +576,36 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if not self.mamba_cache:
self._prepare_mamba_cache()
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
mamba_cache = self._release_finished_and_prepare_mamba_cache(
finished_requests_ids, request_ids_to_seq_ids)
else:
# CUDA graph capturing runs
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
layers_type = self.config.layers_block_type
num_mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache[0],
mamba_cache[1])
attn_metadata, mamba_cache_tensors[0],
mamba_cache_tensors[1])
return hidden_states
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_and_prepare_mamba_cache(
self, finished_requests_ids,
request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
self._release_mamba_cache(finished_requests_ids)
return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
self._release_finished_and_prepare_mamba_cache(
kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self
) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape = (
@@ -790,31 +613,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
self.config.mamba_expand * self.config.hidden_size // world_size,
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
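
For a concrete sense of these two shapes, a small worked example with hypothetical config values (not taken from any real Jamba checkpoint; the first conv dimension is assumed to mirror the temporal one, which the surrounding code suggests but this hunk does not show):

    # Hypothetical values, for illustration only.
    hidden_size = 4096
    mamba_expand = 2
    mamba_d_conv = 4
    mamba_d_state = 16
    world_size = 2  # tensor-parallel size

    # Assumed per-rank leading dimension: mamba_expand * hidden_size // world_size
    conv_state_shape = (mamba_expand * hidden_size // world_size,
                        mamba_d_conv - 1)   # -> (4096, 3)
    temporal_state_shape = (mamba_expand * hidden_size // world_size,
                            mamba_d_state)  # -> (4096, 16)

The deleted _prepare_mamba_cache below allocated tensors of size (mamba_layers, max_batch_size) + shape from these per-sequence shapes; the MambaCacheManager constructor receives the same dtype, layer count, batch size, and shapes.
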
def _prepare_mamba_cache(self):
dtype = self.lm_head.weight.dtype
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
def compute_logits(
self,
hidden_states: torch.Tensor,