[5/N] pass the whole config to model (#9983)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -11,9 +11,8 @@ from vllm.utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: C_co,
|
||||
*,
|
||||
cache_config: Optional["CacheConfig"],
|
||||
quant_config: Optional["QuantizationConfig"],
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]):
|
||||
|
||||
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
|
||||
model_init = model.__init__
|
||||
vllm_kws = ("cache_config", "quant_config")
|
||||
missing_kws = tuple(kw for kw in vllm_kws
|
||||
if not supports_kw(model_init, kw))
|
||||
|
||||
if missing_kws and (isinstance(model, type)
|
||||
and issubclass(model, nn.Module)):
|
||||
logger.warning(
|
||||
"The model (%s) is missing "
|
||||
"vLLM-specific keywords from its initializer: %s",
|
||||
model,
|
||||
missing_kws,
|
||||
)
|
||||
|
||||
return len(missing_kws) == 0
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user