[Misc] Move functions into PoolingMetadata (#30027)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-12-04 16:21:19 +08:00
committed by GitHub
parent 5430e110c0
commit 68eb5c8d97
3 changed files with 30 additions and 47 deletions

View File

@@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
PoolerHead,
PoolerNormalize,
PoolingParamsUpdate,
get_prompt_lens,
get_prompt_token_ids,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
@@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
self._get_instruction_len(token_ids.cpu().numpy())
for token_ids in get_prompt_token_ids(pooling_metadata)
for token_ids in pooling_metadata.get_prompt_token_ids()
],
device="cpu",
)