[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -539,9 +539,12 @@ class ModelConfig:
if getattr(self.pooler_config, k) is None:
setattr(self.pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if self.pooler_config.pooling_type is None:
self.pooler_config.pooling_type = default_pooling_type
default_seq_pooling_type = self._model_info.default_seq_pooling_type
if self.pooler_config.seq_pooling_type is None:
self.pooler_config.seq_pooling_type = default_seq_pooling_type
default_tok_pooling_type = self._model_info.default_tok_pooling_type
if self.pooler_config.tok_pooling_type is None:
self.pooler_config.tok_pooling_type = default_tok_pooling_type
self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
@@ -1543,8 +1546,8 @@ class ModelConfig:
@property
def attn_type(self) -> AttnTypeStr:
if self.pooler_config is not None:
pooling_type = self._model_info.default_pooling_type.lower()
if pooling_type == "cls":
seq_pooling_type = self._model_info.default_seq_pooling_type
if seq_pooling_type == "CLS":
return "encoder_only"
else:
is_causal = getattr(self.hf_config, "is_causal", True)
@@ -1561,89 +1564,102 @@ class ModelConfig:
@property
def is_chunked_prefill_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not support "
"chunked prefill."
"Pooling models with bidirectional attn "
"do not support chunked prefill."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["mean", "step", "cls"]:
if attn_type == "decoder":
if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"do not support chunked prefill.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return False
elif pooling_type in ["all", "last"]:
else:
logger.debug(
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"support chunked prefill.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
else:
# for generative models
if attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support chunked prefill.")
logger.debug("Encoder decoder models do not support chunked prefill.")
return False
logger.debug("Generative models support chunked prefill.")
return True
@property
def is_prefix_caching_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not "
"support prefix caching."
"Pooling models with bidirectional attn "
"do not support prefix caching."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["mean", "step", "cls"]:
if attn_type == "decoder":
if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"do not support prefix caching.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return False
elif pooling_type in ["all", "last"]:
else:
logger.debug(
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"support prefix caching.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
else:
# for generative models
if attn_type == "hybrid":
logger.debug(
"Hybrid models does not support prefix caching since the feature "
"Hybrid models do not support prefix caching since the feature "
"is still experimental."
)
return False
elif attn_type == "attention_free":
logger.debug(
"Attention free models does not support prefix caching since the "
"Attention free models do not support prefix caching since the "
"feature is still experimental."
)
return False
elif attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support prefix caching.")
logger.debug("Encoder decoder models do not support prefix caching.")
return False
else: # attn_type == "decoder"
logger.debug("Generative models support prefix caching.")