[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
|
||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
|
||||
def as_embedding_model(cls: _T) -> _T:
|
||||
"""Subclass an existing vLLM model to support embeddings."""
|
||||
# Avoid modifying existing embedding models
|
||||
if is_embedding_model(cls):
|
||||
if is_pooling_model(cls):
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
class ModelForEmbedding(cls, VllmModelForEmbedding):
|
||||
class ModelForEmbedding(cls, VllmModelForPooling):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user