[Model] Pipeline Parallel Support for DeepSeek v2 (#6519)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
@@ -31,6 +31,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
|||||||
_PP_SUPPORTED_MODELS = [
|
_PP_SUPPORTED_MODELS = [
|
||||||
"AquilaModel",
|
"AquilaModel",
|
||||||
"AquilaForCausalLM",
|
"AquilaForCausalLM",
|
||||||
|
"DeepseekV2ForCausalLM",
|
||||||
"InternLMForCausalLM",
|
"InternLMForCausalLM",
|
||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"LLaMAForCausalLM",
|
"LLaMAForCausalLM",
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (get_pp_group,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@@ -49,6 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
|
|
||||||
|
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MLP(nn.Module):
|
class DeepseekV2MLP(nn.Module):
|
||||||
|
|
||||||
@@ -59,17 +62,20 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
hidden_act: str,
|
hidden_act: str,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size, [intermediate_size] * 2,
|
hidden_size, [intermediate_size] * 2,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.gate_up_proj")
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=reduce_results)
|
reduce_results=reduce_results,
|
||||||
|
prefix=f"{prefix}.down_proj")
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
"Only silu is supported for now.")
|
"Only silu is supported for now.")
|
||||||
@@ -88,6 +94,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
@@ -112,12 +119,14 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
use_grouped_topk=True,
|
use_grouped_topk=True,
|
||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
topk_group=config.topk_group)
|
topk_group=config.topk_group,
|
||||||
|
prefix=f"{prefix}.experts")
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(config.hidden_size,
|
self.gate = ReplicatedLinear(config.hidden_size,
|
||||||
config.n_routed_experts,
|
config.n_routed_experts,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=None)
|
quant_config=None,
|
||||||
|
prefix=f"{prefix}.gate")
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
intermediate_size = (config.moe_intermediate_size *
|
intermediate_size = (config.moe_intermediate_size *
|
||||||
config.n_shared_experts)
|
config.n_shared_experts)
|
||||||
@@ -172,10 +181,9 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
layer_idx=None,
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.qk_nope_head_dim = qk_nope_head_dim
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
self.qk_rope_head_dim = qk_rope_head_dim
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
@@ -195,38 +203,44 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||||
self.q_lora_rank,
|
self.q_lora_rank,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_a_proj")
|
||||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||||
self.num_heads *
|
self.num_heads *
|
||||||
self.qk_head_dim,
|
self.qk_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_b_proj")
|
||||||
else:
|
else:
|
||||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||||
self.num_heads *
|
self.num_heads *
|
||||||
self.qk_head_dim,
|
self.qk_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_proj")
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
self.kv_lora_rank +
|
self.hidden_size,
|
||||||
self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.kv_b_proj = ColumnParallelLinear(
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.kv_b_proj")
|
||||||
# O projection.
|
# O projection.
|
||||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj")
|
||||||
rope_scaling['type'] = 'deepseek_yarn'
|
rope_scaling['type'] = 'deepseek_yarn'
|
||||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
@@ -308,7 +322,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
layer_idx: int,
|
prefix: str,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -318,6 +332,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
8192)
|
8192)
|
||||||
|
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||||
|
# with the layer's index.
|
||||||
|
layer_idx = int(prefix.split(sep='.')[-1])
|
||||||
self.self_attn = DeepseekV2Attention(
|
self.self_attn = DeepseekV2Attention(
|
||||||
config=config,
|
config=config,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -333,18 +350,23 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
layer_idx=layer_idx,
|
prefix=f"{prefix}.self_attn",
|
||||||
)
|
)
|
||||||
if (config.n_routed_experts is not None
|
if (config.n_routed_experts is not None
|
||||||
and layer_idx >= config.first_k_dense_replace
|
and layer_idx >= config.first_k_dense_replace
|
||||||
and layer_idx % config.moe_layer_freq == 0):
|
and layer_idx % config.moe_layer_freq == 0):
|
||||||
self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
|
self.mlp = DeepseekV2MoE(
|
||||||
|
config=config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.mlp = DeepseekV2MLP(
|
self.mlp = DeepseekV2MLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
@@ -389,23 +411,34 @@ class DeepseekV2Model(nn.Module):
|
|||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
if get_pp_group().is_first_rank:
|
||||||
config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.hidden_size,
|
config.vocab_size,
|
||||||
)
|
config.hidden_size,
|
||||||
self.layers = nn.ModuleList([
|
)
|
||||||
DeepseekV2DecoderLayer(config,
|
else:
|
||||||
layer_idx,
|
self.embed_tokens = PPMissingLayer()
|
||||||
cache_config=cache_config,
|
|
||||||
quant_config=quant_config)
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
for layer_idx in range(config.num_hidden_layers)
|
config.num_hidden_layers,
|
||||||
])
|
lambda prefix: DeepseekV2DecoderLayer(
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
config,
|
||||||
|
prefix,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
),
|
||||||
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = PPMissingLayer()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -413,14 +446,28 @@ class DeepseekV2Model(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
if get_pp_group().is_first_rank:
|
||||||
residual = None
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
for i in range(len(self.layers)):
|
residual = None
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(positions, hidden_states,
|
hidden_states, residual = layer(positions, hidden_states,
|
||||||
kv_caches[i], attn_metadata,
|
kv_caches[i - self.start_layer],
|
||||||
residual)
|
attn_metadata, residual)
|
||||||
|
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual
|
||||||
|
})
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -436,7 +483,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
self.model = DeepseekV2Model(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
prefix="model")
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
@@ -452,7 +502,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
attn_metadata)
|
attn_metadata, intermediate_tensors)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(self, hidden_states: torch.Tensor,
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
@@ -469,6 +519,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
|
def make_empty_intermediate_tensors(
|
||||||
|
self, batch_size: int, dtype: torch.dtype,
|
||||||
|
device: torch.device) -> IntermediateTensors:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states":
|
||||||
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device),
|
||||||
|
"residual":
|
||||||
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device),
|
||||||
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
@@ -504,6 +568,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -514,6 +582,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param,
|
weight_loader(param,
|
||||||
@@ -527,6 +599,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
Reference in New Issue
Block a user