[Model] Update pooling model interface (#21058)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -15,24 +15,31 @@ class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""API parameters for pooling models. This is currently a placeholder.
|
||||
"""API parameters for pooling models. This
|
||||
|
||||
Attributes:
|
||||
dimensions: Reduce the dimensions of embeddings
|
||||
if model support matryoshka representation.
|
||||
additional_data: Any additional data needed for pooling.
|
||||
"""
|
||||
|
||||
dimensions: Optional[int] = None
|
||||
|
||||
use_cross_encoder: bool = False
|
||||
additional_data: Optional[Any] = None
|
||||
"""Internal use only."""
|
||||
|
||||
logits_processing_needs_token_ids: bool = False
|
||||
"""Internal use only."""
|
||||
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return PoolingParams(dimensions=self.dimensions,
|
||||
use_cross_encoder=self.use_cross_encoder,
|
||||
additional_data=self.additional_data)
|
||||
return PoolingParams(
|
||||
dimensions=self.dimensions,
|
||||
use_cross_encoder=self.use_cross_encoder,
|
||||
logits_processing_needs_token_ids=self.
|
||||
logits_processing_needs_token_ids,
|
||||
)
|
||||
|
||||
def verify(self, model_config: "ModelConfig") -> None:
|
||||
if self.dimensions is not None:
|
||||
@@ -54,10 +61,12 @@ class PoolingParams(
|
||||
raise ValueError("Dimensions must be greater than 0")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingParams("
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
||||
f"additional_metadata={self.additional_data})")
|
||||
return (
|
||||
f"PoolingParams("
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
||||
f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
|
||||
|
||||
Reference in New Issue
Block a user