[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Murali Andoorveedu
Date: 2024-10-03 19:56:58 -07:00 (committed by GitHub)
Commit: 0f6d7a9a34 (parent 303d44790a)
69 changed files with 2585 additions and 1344 deletions

vllm/model_executor/models/chatglm.py

@@ -2,7 +2,7 @@
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -10,15 +10,14 @@ from torch.nn import LayerNorm
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -28,14 +27,16 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import ChatGLMConfig

-from .interfaces import SupportsLoRA
+from .interfaces import SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)


 class GLMAttention(nn.Module):

     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -126,7 +127,7 @@ class GLMMLP(nn.Module):

     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -169,7 +170,7 @@ class GLMBlock(nn.Module):

     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -240,9 +241,10 @@ class GLMTransformer(nn.Module):

     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.post_layer_norm = config.post_layer_norm
@@ -251,10 +253,11 @@
         self.num_layers = config.num_layers

         # Transformer layers.
-        self.layers = nn.ModuleList([
-            GLMBlock(config, cache_config, quant_config)
-            for i in range(self.num_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            self.num_layers,
+            lambda prefix: GLMBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.layers",
+        )

         if self.post_layer_norm:
             layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
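
Note: make_layers (imported from .utils) is what makes the block stack pipeline-parallel: it assigns this rank a contiguous slice of config.num_layers blocks and returns (start_layer, end_layer, layers), with placeholders standing in for blocks owned by other ranks. The sketch below is only a hedged illustration of the slicing, assuming a near-even split; the real helper lives in vllm/model_executor/models/utils.py and may differ in detail.

    from typing import Tuple

    # Hedged sketch of the layer split a make_layers-style helper performs.
    def split_layers(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
        base, rem = divmod(num_layers, pp_size)
        # Earlier ranks absorb one extra layer each until the remainder is used up.
        start = pp_rank * base + min(pp_rank, rem)
        end = start + base + (1 if pp_rank < rem else 0)
        return start, end

    # Example: 28 ChatGLM blocks over 2 pipeline ranks.
    assert split_layers(28, 0, 2) == (0, 14)
    assert split_layers(28, 1, 2) == (14, 28)
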
@@ -269,16 +272,16 @@
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        for i in range(self.num_layers):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-                kv_cache=kv_caches[i],
+                kv_cache=kv_caches[i - self.start_layer],
                 attn_metadata=attn_metadata,
             )
         # Final layer norm.
-        if self.post_layer_norm:
+        if get_pp_group().is_last_rank and self.post_layer_norm:
             hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
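
The loop now covers only this rank's slice, and kv_caches is indexed locally because each pipeline rank allocates KV caches only for the layers it owns; the final layer norm likewise runs only on the last rank, the only one that produces final hidden states. A hedged illustration of the global-to-local cache indexing (numbers are hypothetical):

    # Hypothetical numbers: the last of 2 ranks owns global layers 14..27.
    start_layer, end_layer = 14, 28
    kv_caches = [f"cache_{j}" for j in range(end_layer - start_layer)]  # 14 local entries
    for i in range(start_layer, end_layer):
        local_cache = kv_caches[i - start_layer]  # global layer i -> local slot
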
@@ -288,7 +291,7 @@ class ChatGLMModel(nn.Module):

     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -305,6 +308,9 @@ class ChatGLMModel(nn.Module):
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                            config.hidden_size,
                                            quant_config=quant_config)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.hidden_size))

     def forward(
         self,
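
make_empty_intermediate_tensors_factory gives the model a callable that builds zero-filled placeholders with the listed keys and the model's hidden size, so a non-first rank knows the shape of the tensors it will receive from the previous stage. A hedged sketch of the idea, using a plain dict in place of vLLM's IntermediateTensors wrapper:

    from typing import Callable, Dict, List

    import torch

    # Hedged sketch, not vLLM's exact implementation.
    def empty_intermediate_tensors_factory(keys: List[str],
                                           hidden_size: int) -> Callable:

        def make(batch_size: int, dtype: torch.dtype,
                 device: torch.device) -> Dict[str, torch.Tensor]:
            return {
                key: torch.zeros((batch_size, hidden_size),
                                 dtype=dtype,
                                 device=device) for key in keys
            }

        return make
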
@@ -312,8 +318,12 @@
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        inputs_embeds = self.embedding(input_ids)
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            inputs_embeds = self.embedding(input_ids)
+        else:
+            inputs_embeds = intermediate_tensors["hidden_states"]

         # Run encoder.
         hidden_states = self.encoder(
@@ -322,10 +332,13 @@
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
         )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
+
         return hidden_states


-class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
+class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -362,6 +375,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)

     def forward(
         self,
@@ -370,9 +385,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states

     def compute_logits(
@@ -402,6 +417,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            if is_pp_missing_parameter(name, self):
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
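
is_pp_missing_parameter lets each rank skip checkpoint entries for layers it does not own, since those parameters do not exist in this rank's params_dict. The sketch below shows the underlying idea with an explicit layer-index check; it is a hedged illustration only, as the real helper in vllm/model_executor/models/utils.py works against the model's placeholder layers rather than by parsing names.

    import re

    # Hedged sketch of the idea, not vLLM's implementation.
    def pp_missing(name: str, start_layer: int, end_layer: int) -> bool:
        match = re.search(r"\.layers\.(\d+)\.", name)
        if match is None:
            return False  # embeddings, lm_head, etc. are handled separately
        layer_idx = int(match.group(1))
        return not (start_layer <= layer_idx < end_layer)

    # Example: a rank owning layers 14..27 skips a layer-3 weight.
    assert pp_missing("encoder.layers.3.mlp.dense_h_to_4h.weight", 14, 28) is True
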