[Model] PP support for Mamba-like models (#10992)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from transformers import MambaConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
@@ -18,13 +19,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree)
|
||||
IsAttentionFree, SupportsPP)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .utils import maybe_prefix
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@@ -95,15 +99,17 @@ class MambaModel(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
decoder_layers = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
decoder_layers.append(
|
||||
MambaDecoderLayer(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
self.layers = nn.ModuleList(decoder_layers)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MambaDecoderLayer(
|
||||
config, cache_config=cache_config, quant_config=quant_config),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm_f = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
@@ -114,29 +120,40 @@ class MambaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(i))
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer))
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.norm_f(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -148,7 +165,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.backbone = MambaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "backbone"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
@@ -174,6 +193,9 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.backbone.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.get_input_embeddings(input_ids)
|
||||
|
||||
@@ -189,9 +211,12 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
max_batch_size = (VllmConfig.get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config
|
||||
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
|
||||
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, self.config.num_hidden_layers,
|
||||
max_batch_size, *self._get_mamba_cache_shape())
|
||||
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
|
||||
*self._get_mamba_cache_shape())
|
||||
|
||||
(
|
||||
mamba_cache_tensors,
|
||||
@@ -204,7 +229,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
state_indices_tensor)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, inputs_embeds)
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -252,6 +278,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
Reference in New Issue
Block a user