[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-01 14:36:51 +08:00
committed by GitHub
parent f877a7d12a
commit d2f058e76c
25 changed files with 166 additions and 123 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
from .interfaces_base import VllmModelForPooling, is_pooling_model
_T = TypeVar("_T", bound=type[nn.Module])
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_embedding_model(cls):
if is_pooling_model(cls):
return cls
# Lazy import
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForEmbedding):
class ModelForEmbedding(cls, VllmModelForPooling):
def __init__(
self,