[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-08-12 00:41:37 +08:00
committed by GitHub
parent 16fb668b61
commit 84cf78acee
31 changed files with 452 additions and 261 deletions

View File

@@ -15,11 +15,10 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolingType)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, maybe_prefix
@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
@default_pooling_type("ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
{"encode": Pooler.for_encode(pooler_config)}, )
@default_pooling_type("STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(
pooler_config,
default_pooling_type=PoolingType.STEP,
)
})
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)})