[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)
Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
return self.forward_all(hidden_states, pooling_cursor)
|
||||
|
||||
@@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with CLS pooling"
|
||||
)
|
||||
@@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
class AllPool(PoolingMethod):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.enable_chunked_prefill = (
|
||||
vllm_config.scheduler_config.enable_chunked_prefill
|
||||
)
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify"}
|
||||
|
||||
def forward_all(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with ALL pooling"
|
||||
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
|
||||
) -> PoolerOutput:
|
||||
raise NotImplementedError(
|
||||
"forward_all is not implemented for AllPool. Use forward instead."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooling_cursor = pooling_metadata.pooling_cursor
|
||||
is_finished = pooling_cursor.is_finished()
|
||||
hidden_states_lst = list(
|
||||
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
||||
)
|
||||
return [hidden_states_lst[i] for i in pooling_cursor.index]
|
||||
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
|
||||
|
||||
if not self.enable_chunked_prefill:
|
||||
return hidden_states_lst
|
||||
|
||||
pooling_states = pooling_metadata.pooling_states
|
||||
|
||||
# If chunked_prefill is enabled
|
||||
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
|
||||
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
|
||||
p.hidden_states_cache.append(hs_chunk)
|
||||
|
||||
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||
output_list: PoolerOutput = []
|
||||
for p, finished in zip(pooling_states, is_finished):
|
||||
if finished:
|
||||
hidden_states_cache = p.hidden_states_cache
|
||||
if len(hidden_states_cache) == 1:
|
||||
output_list.append(hidden_states_cache[0])
|
||||
else:
|
||||
output_list.append(torch.concat(hidden_states_cache, dim=0))
|
||||
p.clean()
|
||||
else:
|
||||
output_list.append(None)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
class MeanPool(PoolingMethod):
|
||||
@@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_cursor: PoolingCursor,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
assert not pooling_cursor.is_partial_prefill(), (
|
||||
"partial prefill not supported with MEAN pooling"
|
||||
)
|
||||
@@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
) -> PoolerOutput:
|
||||
return self.activation(pooled_data)
|
||||
|
||||
|
||||
@@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
self,
|
||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
):
|
||||
) -> PoolerOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
# pooled_data shape: [batchsize, hidden_dimension]
|
||||
@@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):
|
||||
|
||||
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||
def forward(
|
||||
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
|
||||
) -> torch.Tensor:
|
||||
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
|
||||
) -> PoolerOutput:
|
||||
# for unfinished chunked prefill
|
||||
if pooled_data is None:
|
||||
return None
|
||||
|
||||
pooled_data = pooled_data.to(self.head_dtype)
|
||||
# pooled_data shape: [n_tokens, hidden_dimension]
|
||||
|
||||
@@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states: torch.Tensor | None,
|
||||
pooling_param: PoolingParams,
|
||||
) -> torch.Tensor:
|
||||
) -> PoolerOutput:
|
||||
# for unfinished chunked prefill
|
||||
if hidden_states is None:
|
||||
return None
|
||||
|
||||
hidden_states = hidden_states.to(self.head_dtype)
|
||||
# hidden_states shape: [n_token, hidden_size]
|
||||
|
||||
@@ -686,17 +732,20 @@ class StepPooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
) -> PoolerOutput:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||
|
||||
pooled_data = list[torch.Tensor]()
|
||||
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
pooled_data: PoolerOutput = []
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
):
|
||||
# for unfinished chunked prefill
|
||||
if data is None:
|
||||
pooled_data.append(data)
|
||||
continue
|
||||
|
||||
step_tag_id = pooling_param.step_tag_id
|
||||
returned_token_ids = pooling_param.returned_token_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user