[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug (#6425)

Co-authored-by: Mor Zusman <morz@ai21.com>
This commit is contained in:
Mor Zusman
2024-07-16 04:32:55 +03:00
committed by GitHub
parent d6f3b3d5c4
commit 9ad32dacd9
5 changed files with 176 additions and 24 deletions

View File

@@ -13,7 +13,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import HasInnerState
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
return hidden_states
class JambaForCausalLM(nn.Module):
class JambaForCausalLM(nn.Module, HasInnerState):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.model = JambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size)
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
batch_size,
finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
return indices_for_current_run
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
if request_id in finished_requests_ids:
# Do not allocate cache for requests that run
# and finish right after
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size)
cg_batch_size,
finished_requests_ids)
self.current_indices = indices
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_mamba_cache(finished_requests_ids)
for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None
for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,