[Model] PP support for Mamba-like models (#10992)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
@@ -27,8 +27,8 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
-from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, random_uuid,
+from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+                        get_cpu_memory, print_warning_once, random_uuid,
                          resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
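The only functional change in this first hunk is bringing LayerBlockType into scope. Judging from how it is used later in the diff (LayerBlockType.attention as a default argument, block_type.value compared against entries of hf_config.layers_block_type), its shape is roughly the sketch below; this is inferred from usage, the real definition lives in vllm/utils.py:

    import enum

    # Assumed shape, inferred from how it is used in this diff;
    # see vllm/utils.py for the actual definition.
    class LayerBlockType(enum.Enum):
        attention = "attention"
        mamba = "mamba"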
@@ -284,6 +284,7 @@ class ModelConfig:
         self._verify_tokenizer_mode()
 
         self.is_attention_free = self._init_attention_free()
+        self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()
 
         if current_platform.is_neuron():
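The new is_hybrid flag sits next to the existing is_attention_free and has_inner_state flags. As a rough illustration of how the three are expected to combine (assumed values, not read from ModelRegistry):

    # Illustrative expectations only, not queried from the registry:
    #   LlamaForCausalLM -> is_attention_free=False, is_hybrid=False, has_inner_state=False
    #   MambaForCausalLM -> is_attention_free=True,  is_hybrid=False, has_inner_state=True
    #   JambaForCausalLM -> is_attention_free=False, is_hybrid=True,  has_inner_state=True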
@@ -340,6 +341,10 @@ class ModelConfig:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
 
+    def _init_is_hybrid(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_hybrid_model(architectures)
+
     def _init_has_inner_state(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.model_has_inner_state(architectures)
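For reference, a hedged usage sketch of the registry call added here. The lookup is keyed by the architecture names in hf_config.architectures; a hybrid attention/Mamba checkpoint such as Jamba is the kind of model expected to be flagged, while a pure transformer is not (the expected outputs are assumptions, not taken from this PR):

    from vllm import ModelRegistry

    # Expected results; which architectures count as hybrid is assumed here.
    print(ModelRegistry.is_hybrid_model(["JambaForCausalLM"]))  # True
    print(ModelRegistry.is_hybrid_model(["LlamaForCausalLM"]))  # False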
@@ -669,26 +674,51 @@ class ModelConfig:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return start, end
+
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        start, end = self.get_layers_start_end_indices(parallel_config)
         return end - start
 
-    def get_num_attention_layers(self,
-                                 parallel_config: "ParallelConfig") -> int:
-        if self.is_attention_free:
-            return 0
+    def get_num_layers_by_block_type(
+        self,
+        parallel_config: "ParallelConfig",
+        block_type: LayerBlockType = LayerBlockType.attention,
+    ) -> int:
+        # This function relies on 'layers_block_type' in hf_config,
+        # for w/o this attribute, we will need to have workarounds like so
+        attn_block_type = block_type == LayerBlockType.attention
+        is_transformer = not self.is_hybrid and not self.is_attention_free
+        start, end = self.get_layers_start_end_indices(parallel_config)
 
-        num_layers = self.get_num_layers(parallel_config)
-
+        if is_transformer:
+            # Handle the basic case first
+            return end - start if attn_block_type else 0
+        elif self.is_attention_free:
+            # Attention free
+            # Note that this code assumes there
+            # is only one type of attention-free block type.
+            return 0 if attn_block_type else end - start
+        else:
+            # Hybrid model
+            layers_block_type_value = getattr(self.hf_config,
+                                              "layers_block_type", None)
+            if layers_block_type_value is None:
+                raise ValueError("The model is an hybrid without a"
+                                 "layers_block_type in the hf_config,"
+                                 "cannot determine the num of "
+                                 f"{block_type.value} layers")
+
-        # Transformers supports layers_block_type @property
-        layers = getattr(self.hf_config, "layers_block_type",
-                         ["attention"] * num_layers)
-        return len([t for t in layers if t == "attention"])
+            return sum(t == block_type.value
+                       for t in layers_block_type_value[start:end])
 
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
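Taken together, the last hunk lets each pipeline-parallel rank count only the layers it owns, broken down by block type. Below is a minimal standalone sketch of that logic, not vLLM code: the even split is a simplification of what get_pp_indices does, and the layers_block_type pattern is just an example.

    def even_partition(total_layers: int, pp_rank: int, pp_size: int):
        # Simplified even split across pipeline ranks; the real get_pp_indices
        # also handles uneven and user-defined partitions.
        per_rank = total_layers // pp_size
        start = pp_rank * per_rank
        return start, start + per_rank

    # Example hybrid layout (illustrative, not any specific model's pattern).
    layers_block_type = ["mamba", "mamba", "attention", "mamba"] * 2

    for rank in range(2):  # pp_size = 2
        start, end = even_partition(len(layers_block_type), rank, 2)
        attention = sum(t == "attention" for t in layers_block_type[start:end])
        mamba = sum(t == "mamba" for t in layers_block_type[start:end])
        print(f"pp_rank={rank}: layers [{start}, {end}) -> "
              f"{attention} attention, {mamba} mamba")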