[core][distributed] simplify code to support pipeline parallel (#6406)
@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
-from vllm.distributed.utils import get_pp_indices
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
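The only change in this first hunk is dropping the direct import of get_pp_indices: the layer-partitioning logic moves behind the make_layers helper imported in the next hunk. For orientation, get_pp_indices maps (num_hidden_layers, pp_rank, pp_size) to the half-open [start_layer, end_layer) slice of blocks owned by one pipeline stage. Below is a minimal sketch of that mapping, assuming an even split with the last stage absorbing any remainder; the function name pp_layer_range is hypothetical and this is not the exact vLLM implementation.

from typing import Tuple


def pp_layer_range(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    # Even split across pipeline stages; the last stage takes the
    # remainder when num_hidden_layers is not divisible by pp_size.
    per_stage = num_hidden_layers // pp_size
    start = pp_rank * per_stage
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end


# Example: 12 layers on 2 stages -> rank 0 owns [0, 6), rank 1 owns [6, 12).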
@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class GPT2Attention(nn.Module):
 
@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
         self.embed_dim = config.hidden_size
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.start_layer, self.end_layer = get_pp_indices(
-            config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.h = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                GPT2Block(config, cache_config, quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda: GPT2Block(config, cache_config, quant_config))
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
     def forward(
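The constructor hunk above collapses the hand-rolled get_pp_indices call and nn.ModuleList assembly into a single make_layers call that takes the layer count and a zero-argument layer factory. Judging only from the removed code, a helper of this shape plausibly looks like the sketch below; the explicit pp_rank/pp_size parameters are an assumption made to keep the sketch self-contained (the real helper presumably reads them from get_pp_group()).

from typing import Callable, Tuple

import torch.nn as nn


def make_layers_sketch(num_hidden_layers: int,
                       layer_fn: Callable[[], nn.Module],
                       pp_rank: int,
                       pp_size: int) -> Tuple[int, int, nn.ModuleList]:
    # Which contiguous slice of blocks does this pipeline stage own?
    per_stage = num_hidden_layers // pp_size
    start = pp_rank * per_stage
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_stage
    # Real blocks for the owned slice, nn.Identity placeholders elsewhere,
    # mirroring the nn.ModuleList construction removed above.
    layers = nn.ModuleList(
        [nn.Identity() for _ in range(start)] +
        [layer_fn() for _ in range(start, end)] +
        [nn.Identity() for _ in range(end, num_hidden_layers)])
    return start, end, layers

Keeping nn.Identity placeholders preserves the global layer numbering in parameter names, but the placeholders contribute no parameters of their own, which is exactly what the weight-loading change below has to account for.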
@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
-            try:
-                param = params_dict[name]
-                # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-                # Because of this, we need to transpose the weights.
-                # Note(zhuohan): the logic below might break quantized models.
-                for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                    if conv1d_weight_name not in name:
-                        continue
-                    if not name.endswith(".weight"):
-                        continue
-                    loaded_weight = loaded_weight.t()
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            except KeyError:
-                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
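In load_weights, the blanket try/except KeyError is replaced with an explicit is_pp_missing_parameter check: under pipeline parallelism, blocks owned by other stages are nn.Identity placeholders, so their parameter names never appear in this rank's params_dict, and skipping them explicitly avoids masking unrelated KeyErrors. A rough sketch of what such a check could do, assuming GPT-2 style parameter names ("transformer.h.<idx>. ...") and the start_layer/end_layer attributes set by make_layers above; the regex, helper name, and attribute access are illustrative, not the actual vLLM implementation.

import re

import torch.nn as nn

# Assumed naming pattern for per-block parameters, e.g.
# "transformer.h.7.attn.c_attn.weight".
_LAYER_IDX_RE = re.compile(r"\.h\.(\d+)\.")


def is_pp_missing_parameter_sketch(name: str, model: nn.Module) -> bool:
    # A parameter is "missing" on this pipeline stage if it belongs to a
    # decoder block outside the [start_layer, end_layer) slice owned here.
    match = _LAYER_IDX_RE.search(name)
    if match is None:
        return False  # non-block parameters; not handled by this sketch
    layer_idx = int(match.group(1))
    start = model.transformer.start_layer  # set in GPT2Model.__init__ above
    end = model.transformer.end_layer
    return not (start <= layer_idx < end)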