[Core] Pipeline Parallel Support (#4412)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
committed by
GitHub
parent
15aba081f3
commit
c5832d2ae9
@@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -25,7 +25,9 @@ from transformers import GPT2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -38,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
@@ -181,10 +183,18 @@ class GPT2Model(nn.Module):
|
||||
self.embed_dim = config.hidden_size
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList([
|
||||
GPT2Block(config, cache_config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.start_layer, self.end_layer = get_pp_indices(
|
||||
config.num_hidden_layers,
|
||||
get_pp_group().rank_in_group,
|
||||
get_pp_group().world_size)
|
||||
self.h = nn.ModuleList(
|
||||
[nn.Identity() for _ in range(self.start_layer)] + [
|
||||
GPT2Block(config, cache_config, quant_config)
|
||||
for _ in range(self.start_layer, self.end_layer)
|
||||
] + [
|
||||
nn.Identity()
|
||||
for _ in range(self.end_layer, config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@@ -193,14 +203,24 @@ class GPT2Model(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
for i in range(len(self.h)):
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
|
||||
hidden_states = layer(hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@@ -228,9 +248,10 @@ class GPT2LMHeadModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
attn_metadata, intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
@@ -247,6 +268,16 @@ class GPT2LMHeadModel(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
@@ -260,16 +291,19 @@ class GPT2LMHeadModel(nn.Module):
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = params_dict[name]
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
# Note(zhuohan): the logic below might break quantized models.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
if conv1d_weight_name not in name:
|
||||
continue
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
try:
|
||||
param = params_dict[name]
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
# Note(zhuohan): the logic below might break quantized models.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
if conv1d_weight_name not in name:
|
||||
continue
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user