[Misc] Move functions into PoolingMetadata (#30027)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolerHead,
|
||||
PoolerNormalize,
|
||||
PoolingParamsUpdate,
|
||||
get_prompt_lens,
|
||||
get_prompt_token_ids,
|
||||
)
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.tasks import PoolingTask
|
||||
@@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
self._get_instruction_len(token_ids.cpu().numpy())
|
||||
for token_ids in get_prompt_token_ids(pooling_metadata)
|
||||
for token_ids in pooling_metadata.get_prompt_token_ids()
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user