[Model] PP support for Mamba-like models (#10992)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@@ -25,9 +26,12 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import HasInnerState, SupportsLoRA
|
||||
from .utils import maybe_prefix
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@@ -281,16 +285,24 @@ class JambaModel(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
decoder_layers = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
|
||||
decoder_layers.append(
|
||||
layer_class(config,
|
||||
layer_idx=i,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{i}"))
|
||||
self.layers = nn.ModuleList(decoder_layers)
|
||||
def get_layer(prefix: str):
|
||||
layer_idx = int(prefix.rsplit(".", 1)[1])
|
||||
layer_class = ALL_DECODER_LAYER_TYPES[
|
||||
config.layers_block_type[layer_idx]]
|
||||
return layer_class(
|
||||
config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
@@ -304,26 +316,34 @@ class JambaModel(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
kv_cache_index = 0
|
||||
mamba_cache_index = 0
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
kv_cache = None
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer, JambaAttentionDecoderLayer):
|
||||
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
|
||||
self.config.attn_layer_period]
|
||||
kv_cache = kv_caches[kv_cache_index]
|
||||
kv_cache_index += 1
|
||||
if isinstance(layer, JambaMambaDecoderLayer):
|
||||
current_state_layer = i - (1 +
|
||||
(i - self.config.attn_layer_offset)
|
||||
// self.config.attn_layer_period)
|
||||
current_state_layer = mamba_cache_index
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
mamba_cache_index += 1
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
@@ -332,11 +352,17 @@ class JambaModel(nn.Module):
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states, _ = self.final_layernorm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
IsHybrid):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -368,6 +394,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = JambaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
@@ -390,6 +418,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
@@ -406,10 +437,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config
|
||||
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
|
||||
layers_type = self.config.layers_block_type
|
||||
num_mamba_layers = sum(
|
||||
[layer_type == "mamba" for layer_type in layers_type])
|
||||
|
||||
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
|
||||
*self._get_mamba_cache_shape())
|
||||
@@ -423,7 +452,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
state_indices_tensor)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_params,
|
||||
inputs_embeds)
|
||||
intermediate_tensors, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
@@ -504,8 +533,12 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -520,6 +553,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
@@ -533,6 +568,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
Reference in New Issue
Block a user