[5/N] pass the whole config to model (#9983)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-08 22:17:28 -08:00
committed by GitHub
parent 49d2a41a86
commit 1a95f10ee7
75 changed files with 583 additions and 654 deletions

View File

@@ -11,9 +11,8 @@ from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]):
# Initializer signature declared for the model interface; the body is a bare
# `...` placeholder, so only the parameter list carries meaning here.
# NOTE(review): this text appears to be a rendered diff whose +/- markers and
# indentation were stripped, so older and newer parameters may be listed side
# by side (per-component `cache_config` / `quant_config` next to the aggregate
# `vllm_config`) -- confirm the real signature against the actual file.
def __init__(
self,
config: C_co,
*,
# Everything after the bare `*` is keyword-only.
cache_config: Optional["CacheConfig"],
quant_config: Optional["QuantizationConfig"],
vllm_config: "VllmConfig",
prefix: str = "",
) -> None:
...
@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]):
# Report whether the initializer of `model` (a class or an instance) accepts
# the keyword arguments checked below, using `supports_kw` on `model.__init__`.
# Returns True when no checked keyword is missing.
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__
# Keywords the initializer is probed for.
vllm_kws = ("cache_config", "quant_config")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_init, kw))
# Only warn when `model` is itself a class deriving from nn.Module; other
# objects with missing keywords are rejected silently.
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
# NOTE(review): as written this line is unreachable, because the return above
# is unconditional. This text looks like a rendered diff with its +/- markers
# stripped; presumably the keyword-tuple check above is the removed version
# and this single `vllm_config` check is its replacement -- confirm against
# the actual file.
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: