[Misc] Use config definitions from Transformers library (#21913)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -29,7 +29,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@@ -100,7 +100,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
@@ -221,7 +221,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
@@ -373,7 +373,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
@@ -538,7 +538,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
@@ -973,7 +973,10 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
|
||||
# Compatibility with
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
|
||||
def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config,
|
||||
DeepseekV3Config],
|
||||
weight_name: str) -> Optional[int]:
|
||||
if (hasattr(config, "num_nextn_predict_layers")
|
||||
and config.num_nextn_predict_layers > 0):
|
||||
|
||||
Reference in New Issue
Block a user