[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 0f6d7a9a34 (parent 303d44790a)
Author: Murali Andoorveedu
Date:   2024-10-03 19:56:58 -07:00 (committed by GitHub)

69 changed files with 2585 additions and 1344 deletions

vllm/model_executor/models/bloom.py

@@ -17,7 +17,7 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -25,15 +25,14 @@ from transformers import BloomConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
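
The imported `make_layers` helper is what turns a single-GPU layer stack into a pipeline-partitioned one: each rank builds only its contiguous slice of blocks while global layer indices stay valid. A minimal sketch of the idea, assuming an even split and a `PPMissingLayer` placeholder for non-local layers (vLLM's actual helper additionally reads the rank and partition from `get_pp_group()` and the parallel config, so the names below are illustrative):

```python
from typing import Callable, Tuple

from torch import nn


class PPMissingLayer(nn.Module):
    """Placeholder for a block that lives on another pipeline rank."""


def make_layers_sketch(
    num_hidden_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
    pp_rank: int,
    pp_size: int,
) -> Tuple[int, int, nn.ModuleList]:
    # Contiguous, roughly even split of the layer range across stages.
    start_layer = (num_hidden_layers * pp_rank) // pp_size
    end_layer = (num_hidden_layers * (pp_rank + 1)) // pp_size
    # Pad with placeholders so global indexing (self.h[i]) stays valid
    # while only the local slice allocates real parameters.
    layers = nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] +
        [layer_fn(f"{prefix}.{i}") for i in range(start_layer, end_layer)] +
        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, layers
```

For example, with 24 hidden layers and two pipeline stages, rank 0 owns layers [0, 12) and rank 1 owns [12, 24).
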
@@ -222,6 +225,7 @@ class BloomModel(nn.Module):
         config: BloomConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -235,13 +239,16 @@ class BloomModel(nn.Module):
             self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
-        self.h = nn.ModuleList([
-            BloomBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: BloomBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.h")
 
         # Final Layer Norm
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.hidden_size))
 
     def forward(
         self,
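
The `make_empty_intermediate_tensors` attribute exists so that the model runner on a non-first rank can allocate correctly shaped receive buffers before the upstream stage's `hidden_states` arrive. A sketch of what the factory plausibly produces, assuming `IntermediateTensors` is a thin dict wrapper (as its use in `forward` below suggests):

```python
from typing import List

import torch

from vllm.sequence import IntermediateTensors


def make_empty_intermediate_tensors_factory(keys: List[str],
                                            hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        # One zero-filled buffer per key; BloomModel only ships
        # "hidden_states" of shape (batch_size, hidden_size).
        return IntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty_intermediate_tensors
```
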
@@ -249,22 +256,29 @@ class BloomModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.word_embeddings(input_ids)
-        hidden_states = self.word_embeddings_layernorm(hidden_states)
-        for i in range(len(self.h)):
+        intermediate_tensors: Optional[IntermediateTensors],
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            hidden_states = self.word_embeddings(input_ids)
+            hidden_states = self.word_embeddings_layernorm(hidden_states)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
 
-class BloomForCausalLM(nn.Module):
+class BloomForCausalLM(nn.Module, SupportsPP):
 
     def __init__(
         self,
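
The rewritten `forward` is now rank-dependent: the first rank embeds token ids, later ranks resume from the `hidden_states` handed over by the previous stage, and every rank except the last returns an `IntermediateTensors` to ship downstream instead of a finished tensor. Note also that `kv_caches` is indexed relative to `start_layer`, since each rank only holds caches for its own layers. A toy single-process illustration of the hand-off, using a hypothetical `ToyStage` in place of `BloomModel` and plain dicts in place of real cross-rank communication:

```python
import torch
from torch import nn


class ToyStage(nn.Module):
    """Hypothetical stand-in for one pipeline rank's slice of the model."""

    def __init__(self, blocks, is_first: bool, is_last: bool):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(100, 16)
        self.ln_f = nn.LayerNorm(16)

    def forward(self, input_ids, intermediate):
        if self.is_first:
            hidden = self.embed(input_ids)          # first rank embeds
        else:
            hidden = intermediate["hidden_states"]  # others resume upstream state
        for block in self.blocks:
            hidden = block(hidden)
        if not self.is_last:
            # Mirror of `return IntermediateTensors({...})` above.
            return {"hidden_states": hidden}
        return self.ln_f(hidden)                    # last rank finishes


# Four "blocks" split across two stages.
blocks = [nn.Linear(16, 16) for _ in range(4)]
stage0 = ToyStage(blocks[:2], is_first=True, is_last=False)
stage1 = ToyStage(blocks[2:], is_first=False, is_last=True)

ids = torch.randint(0, 100, (2, 8))
shipped = stage0(ids, None)   # would be sent from rank 0 to rank 1
final = stage1(None, shipped)
print(final.shape)            # torch.Size([2, 8, 16])
```
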
@@ -284,6 +298,8 @@ class BloomForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -292,9 +308,9 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(
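
`BloomForCausalLM.forward` itself stays a thin pass-through: it forwards `intermediate_tensors` to the transformer and widens its return type, leaving it to the runner to call `compute_logits` and `sample` only on the last rank, where the result is an actual hidden-state tensor. A hedged sketch of that dispatch (hypothetical `run_stage` helper, not vLLM's model runner):

```python
from typing import Union

import torch

from vllm.sequence import IntermediateTensors


def run_stage(model, input_ids, positions, kv_caches, attn_metadata,
              recv_tensors):
    # Forward this rank's slice of the model.
    out: Union[torch.Tensor, IntermediateTensors] = model(
        input_ids, positions, kv_caches, attn_metadata, recv_tensors)
    if isinstance(out, IntermediateTensors):
        return ("send_downstream", out)  # intermediate rank: ship activations
    return ("compute_logits", out)       # last rank: finished hidden states
```
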
@@ -321,6 +337,8 @@ class BloomForCausalLM(nn.Module):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
+            if is_pp_missing_parameter(name, self):
+                continue
             param = params_dict[name]
 
             if "query_key_value" in name: