Rename clashing method names for vLLM model protocol (#27583)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-11-13 03:14:33 +00:00
committed by GitHub
parent 3226283461
commit 97d1c99302
164 changed files with 574 additions and 583 deletions

View File

@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
class VllmModel(Protocol[T_co]):
"""The interface required for all models in vLLM."""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None: ...
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Apply token embeddings to `input_ids`."""
...
if hasattr(self, "get_input_embeddings"):
logger.warning_once(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
return self.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> T_co: ...
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
def _check_vllm_model_init(model: type[object] | object) -> bool:
@@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if not callable(model_get_input_embeddings):
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
model_embed_input_ids = getattr(model, "embed_input_ids", None)
if not callable(model_embed_input_ids):
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if callable(model_get_input_embeddings):
logger.warning(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
model.embed_input_ids = model_get_input_embeddings
logger.warning(
"The model (%s) is missing the `get_input_embeddings` method.",
"The model (%s) is missing the `embed_input_ids` method.",
model,
)
return False
@@ -110,7 +113,7 @@ def is_vllm_model(
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
return (
_check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model)
and _check_vllm_model_embed_input_ids(model)
and _check_vllm_model_forward(model)
)