[core][distributed] simplify code to support pipeline parallel (#6406)

This commit is contained in:
youkaichao
2024-07-14 21:20:51 -07:00
committed by GitHub
parent 44874a0bf9
commit 69672f116c
5 changed files with 107 additions and 61 deletions

View File

@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.distributed.utils import get_pp_indices
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from .utils import is_pp_missing_parameter, make_layers
class GPT2Attention(nn.Module):
@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer = get_pp_indices(
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
lambda: GPT2Block(config, cache_config, quant_config))
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
try:
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)