Rename clashing method names for vLLM model protocol (#27583)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
|
||||
class VllmModel(Protocol[T_co]):
|
||||
"""The interface required for all models in vLLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None: ...
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply token embeddings to `input_ids`."""
|
||||
...
|
||||
if hasattr(self, "get_input_embeddings"):
|
||||
logger.warning_once(
|
||||
"`get_input_embeddings` for vLLM models is deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
|
||||
"this method to `embed_input_ids`."
|
||||
)
|
||||
return self.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> T_co: ...
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
|
||||
|
||||
|
||||
def _check_vllm_model_init(model: type[object] | object) -> bool:
|
||||
@@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if not callable(model_get_input_embeddings):
|
||||
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
|
||||
model_embed_input_ids = getattr(model, "embed_input_ids", None)
|
||||
if not callable(model_embed_input_ids):
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if callable(model_get_input_embeddings):
|
||||
logger.warning(
|
||||
"`get_input_embeddings` for vLLM models is deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
|
||||
"this method to `embed_input_ids`."
|
||||
)
|
||||
model.embed_input_ids = model_get_input_embeddings
|
||||
logger.warning(
|
||||
"The model (%s) is missing the `get_input_embeddings` method.",
|
||||
"The model (%s) is missing the `embed_input_ids` method.",
|
||||
model,
|
||||
)
|
||||
return False
|
||||
@@ -110,7 +113,7 @@ def is_vllm_model(
|
||||
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
|
||||
return (
|
||||
_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
and _check_vllm_model_embed_input_ids(model)
|
||||
and _check_vllm_model_forward(model)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user