[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
wang.yuqi
2025-12-04 21:44:15 +08:00
committed by GitHub
parent 1b7c7f5159
commit 74c4d80c6c
15 changed files with 224 additions and 93 deletions

View File

@@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
raise NotImplementedError
def forward(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Pool ``hidden_states`` using the cursor carried by the metadata.

    Args:
        hidden_states: Hidden states to pool.
        pooling_metadata: Metadata object whose ``pooling_cursor``
            attribute is handed to ``forward_all``.

    Returns:
        Whatever ``forward_all`` returns for this pooling method.
    """
    # Only the cursor is needed here; subclasses that need more of the
    # metadata can override ``forward`` instead of ``forward_all``.
    pooling_cursor = pooling_metadata.pooling_cursor
    return self.forward_all(hidden_states, pooling_cursor)
@@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod):
    """Pooling method that keeps every token's hidden state.

    When chunked prefill is enabled, the chunks produced by successive
    steps are accumulated in each pooling state's ``hidden_states_cache``
    and only emitted once the cursor reports the entry as finished.
    """

    def __init__(self):
        super().__init__()

        # Snapshot the scheduler setting once; ``forward`` branches on it.
        vllm_config = get_current_vllm_config()
        self.enable_chunked_prefill = (
            vllm_config.scheduler_config.enable_chunked_prefill
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        return {"token_embed", "token_classify"}

    def forward_all(
        self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
    ) -> PoolerOutput:
        """Unsupported entry point; ``forward`` must be used instead.

        Raises:
            NotImplementedError: Always. ``forward`` reads
                ``pooling_metadata.pooling_states``, which a bare cursor
                does not carry.
        """
        raise NotImplementedError(
            "forward_all is not implemented for AllPool. Use forward instead."
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Split ``hidden_states`` into per-entry chunks and emit them.

        Args:
            hidden_states: Tensor that is split along its first dimension
                according to ``pooling_cursor.num_scheduled_tokens_cpu``.
            pooling_metadata: Supplies ``pooling_cursor`` and, when chunked
                prefill is enabled, ``pooling_states``.

        Returns:
            Without chunked prefill: the list of chunks selected and
            ordered by ``pooling_cursor.index``.
            With chunked prefill: one entry per pooling state, either the
            cached chunks concatenated along dim 0 (for finished entries)
            or ``None`` (for entries whose prefill is still in progress).
        """
        pooling_cursor = pooling_metadata.pooling_cursor
        is_finished = pooling_cursor.is_finished()
        hidden_states_lst = list(
            hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
        )
        hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]

        if not self.enable_chunked_prefill:
            return hidden_states_lst

        pooling_states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled:
        # 1. First store this step's chunk in each state's
        #    hidden_states_cache.
        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
            p.hidden_states_cache.append(hs_chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead.
        output_list: PoolerOutput = []
        for p, finished in zip(pooling_states, is_finished):
            if finished:
                hidden_states_cache = p.hidden_states_cache
                if len(hidden_states_cache) == 1:
                    # Single chunk: nothing to concatenate.
                    output_list.append(hidden_states_cache[0])
                else:
                    output_list.append(torch.concat(hidden_states_cache, dim=0))
                # NOTE(review): presumably resets the per-entry cache so it
                # is not re-emitted -- confirm against the state class.
                p.clean()
            else:
                # Unfinished chunked prefill: downstream heads see None.
                output_list.append(None)

        return output_list
class MeanPool(PoolingMethod):
@@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
return self.activation(pooled_data)
@@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward(
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
) -> torch.Tensor:
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]
@@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: torch.Tensor | None,
pooling_param: PoolingParams,
) -> torch.Tensor:
) -> PoolerOutput:
# for unfinished chunked prefill
if hidden_states is None:
return None
hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]
@@ -686,17 +732,20 @@ class StepPooler(Pooler):
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooled_data = list[torch.Tensor]()
pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = []
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
# for unfinished chunked prefill
if data is None:
pooled_data.append(data)
continue
step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids