[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai> Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
303d44790a
commit
0f6d7a9a34
@@ -22,7 +22,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
class MiniCPMMoE(nn.Module):
|
||||
@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self._init_layers()
|
||||
self._init_layers(prefix, config, cache_config, quant_config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size))
|
||||
|
||||
def _init_layers(self):
|
||||
self.layers = nn.ModuleList([
|
||||
MiniCPMDecoderLayer(self.config, self.cache_config,
|
||||
self.quant_config)
|
||||
for _ in range(self.config.num_hidden_layers)
|
||||
])
|
||||
def _init_layers(
|
||||
self,
|
||||
prefix: str,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
|
||||
quant_config),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
embedding = self.embed_tokens(input_ids)
|
||||
@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.logits_processor = LogitsProcessor(unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _init_model(self):
|
||||
self.model = MiniCPMModel(config=self.config,
|
||||
@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
Reference in New Issue
Block a user