[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -539,9 +539,12 @@ class ModelConfig:
             if getattr(self.pooler_config, k) is None:
                 setattr(self.pooler_config, k, v)
 
-        default_pooling_type = self._model_info.default_pooling_type
-        if self.pooler_config.pooling_type is None:
-            self.pooler_config.pooling_type = default_pooling_type
+        default_seq_pooling_type = self._model_info.default_seq_pooling_type
+        if self.pooler_config.seq_pooling_type is None:
+            self.pooler_config.seq_pooling_type = default_seq_pooling_type
+        default_tok_pooling_type = self._model_info.default_tok_pooling_type
+        if self.pooler_config.tok_pooling_type is None:
+            self.pooler_config.tok_pooling_type = default_tok_pooling_type
 
         self.dtype: torch.dtype = _get_and_verify_dtype(
             self.model,
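
Not part of the commit, just an illustration of the hunk above: the single pooling_type default is now back-filled as two independent fields, seq_pooling_type and tok_pooling_type. A minimal, self-contained sketch of that pattern; the PoolerConfig/ModelInfo shapes and the concrete default values below are simplified assumptions, not vLLM's actual classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class PoolerConfig:
        # None means "not set by the user"; model defaults are filled in later.
        seq_pooling_type: Optional[str] = None  # e.g. "MEAN", "CLS", "LAST"
        tok_pooling_type: Optional[str] = None  # e.g. "ALL", "STEP"

    @dataclass
    class ModelInfo:
        # Per-model defaults; the values here are made up for the example.
        default_seq_pooling_type: str = "LAST"
        default_tok_pooling_type: str = "ALL"

    def apply_pooling_defaults(cfg: PoolerConfig, info: ModelInfo) -> PoolerConfig:
        # Same pattern as the diff: only fill in fields the user left unset.
        if cfg.seq_pooling_type is None:
            cfg.seq_pooling_type = info.default_seq_pooling_type
        if cfg.tok_pooling_type is None:
            cfg.tok_pooling_type = info.default_tok_pooling_type
        return cfg

    print(apply_pooling_defaults(PoolerConfig(seq_pooling_type="CLS"), ModelInfo()))
    # PoolerConfig(seq_pooling_type='CLS', tok_pooling_type='ALL')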
@@ -1543,8 +1546,8 @@ class ModelConfig:
     @property
     def attn_type(self) -> AttnTypeStr:
         if self.pooler_config is not None:
-            pooling_type = self._model_info.default_pooling_type.lower()
-            if pooling_type == "cls":
+            seq_pooling_type = self._model_info.default_seq_pooling_type
+            if seq_pooling_type == "CLS":
                 return "encoder_only"
         else:
             is_causal = getattr(self.hf_config, "is_causal", True)
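
Not part of the commit: a standalone sketch of what the attn_type change above does. Encoder-only attention is now inferred from the model's default sequence pooling type ("CLS") rather than from the lowercased pooling_type string. The helper below is an assumption-laden rendering of that rule; the real property also consults hf_config and handles branches this hunk does not show:

    def infer_attn_type(default_seq_pooling_type: str, is_causal: bool = True) -> str:
        # A CLS default for sequence pooling signals a BERT-style encoder,
        # i.e. bidirectional (encoder-only) attention.
        if default_seq_pooling_type == "CLS":
            return "encoder_only"
        # Otherwise fall back to the causal-attention hint, mirroring the
        # getattr(self.hf_config, "is_causal", True) line in the else branch
        # (simplified: the remaining logic is not shown in this diff).
        return "decoder" if is_causal else "encoder_only"

    assert infer_attn_type("CLS") == "encoder_only"
    assert infer_attn_type("LAST") == "decoder"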
@@ -1561,89 +1564,102 @@ class ModelConfig:
     @property
     def is_chunked_prefill_supported(self) -> bool:
         attn_type = self.attn_type
-        if self.pooler_config is not None:
+
+        if pooler_config := self.pooler_config:
             # for pooling models
             if attn_type == "encoder_only":
                 logger.debug(
-                    "Pooling models with bidirectional attn does not support "
-                    "chunked prefill."
+                    "Pooling models with bidirectional attn "
+                    "do not support chunked prefill."
                 )
                 return False
-            elif attn_type == "decoder":
-                pooling_type = self.pooler_config.pooling_type.lower()
-                if pooling_type in ["mean", "step", "cls"]:
+
+            if attn_type == "decoder":
+                if (
+                    pooler_config.seq_pooling_type in ("MEAN", "CLS")
+                    or pooler_config.tok_pooling_type == "STEP"
+                ):
                     logger.debug(
-                        "Pooling models with %s pooling does not "
-                        "support chunked prefill.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "do not support chunked prefill.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return False
-                elif pooling_type in ["all", "last"]:
+                else:
                     logger.debug(
-                        "Pooling models with causal attn and %s pooling support "
-                        "chunked prefill.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "support chunked prefill.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return True
-                else:
-                    raise ValueError(f"{pooling_type=} not supported.")
 
             # vllm currently does not have pooling models using hybrid,
             # attention_free or encoder_decoder attn types.
             return attn_type != "encoder_decoder"
         else:
             # for generative models
             if attn_type == "encoder_decoder":
-                logger.debug("Encoder decoder models does not support chunked prefill.")
+                logger.debug("Encoder decoder models do not support chunked prefill.")
                 return False
 
             logger.debug("Generative models support chunked prefill.")
             return True
 
     @property
     def is_prefix_caching_supported(self) -> bool:
         attn_type = self.attn_type
-        if self.pooler_config is not None:
+
+        if pooler_config := self.pooler_config:
             # for pooling models
             if attn_type == "encoder_only":
                 logger.debug(
-                    "Pooling models with bidirectional attn does not "
-                    "support prefix caching."
+                    "Pooling models with bidirectional attn "
+                    "do not support prefix caching."
                 )
                 return False
-            elif attn_type == "decoder":
-                pooling_type = self.pooler_config.pooling_type.lower()
-                if pooling_type in ["mean", "step", "cls"]:
+
+            if attn_type == "decoder":
+                if (
+                    pooler_config.seq_pooling_type in ("MEAN", "CLS")
+                    or pooler_config.tok_pooling_type == "STEP"
+                ):
                     logger.debug(
-                        "Pooling models with %s pooling does not "
-                        "support prefix caching.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "do not support prefix caching.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return False
-                elif pooling_type in ["all", "last"]:
+                else:
                     logger.debug(
-                        "Pooling models with causal attn and %s pooling support "
-                        "prefix caching.",
-                        pooling_type,
+                        "Pooling models with causal attn and %s/%s pooling "
+                        "support prefix caching.",
+                        pooler_config.seq_pooling_type,
+                        pooler_config.tok_pooling_type,
                     )
                     return True
-                else:
-                    raise ValueError(f"{pooling_type=} not supported.")
 
             # vllm currently does not have pooling models using hybrid,
             # attention_free or encoder_decoder attn types.
             return False
         else:
             # for generative models
             if attn_type == "hybrid":
                 logger.debug(
-                    "Hybrid models does not support prefix caching since the feature "
+                    "Hybrid models do not support prefix caching since the feature "
                     "is still experimental."
                 )
                 return False
             elif attn_type == "attention_free":
                 logger.debug(
-                    "Attention free models does not support prefix caching since the "
+                    "Attention free models do not support prefix caching since the "
                     "feature is still experimental."
                 )
                 return False
             elif attn_type == "encoder_decoder":
-                logger.debug("Encoder decoder models does not support prefix caching.")
+                logger.debug("Encoder decoder models do not support prefix caching.")
                 return False
             else:  # attn_type == "decoder"
                 logger.debug("Generative models support prefix caching.")
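
For orientation (not part of the commit): with the split types, both properties above decide support for pooling models from the seq/tok pooling pair rather than the old single pooling_type string. A simplified, standalone sketch of the shared decoder-attention rule, reusing the type names from the diff (the LAST/ALL examples are assumptions):

    def decoder_pooling_supports_chunking(seq_pooling_type: str, tok_pooling_type: str) -> bool:
        # Mirrors the condition in both properties: MEAN/CLS sequence pooling
        # or STEP token pooling disables chunked prefill and prefix caching;
        # any other combination is allowed.
        return not (seq_pooling_type in ("MEAN", "CLS") or tok_pooling_type == "STEP")

    assert decoder_pooling_supports_chunking("LAST", "ALL") is True
    assert decoder_pooling_supports_chunking("MEAN", "ALL") is False
    assert decoder_pooling_supports_chunking("LAST", "STEP") is False

The non-decoder fallbacks still differ between the two properties (chunked prefill allows everything except encoder_decoder, prefix caching returns False), so the sketch covers only the decoder case.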