[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,19 +2,20 @@ from array import array
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn as nn
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.backends.xformers import XFormersImpl
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import PoolerHead
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||
PoolingTensors)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
|
||||
PoolerOutput)
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -52,6 +53,8 @@ class GritLMPooler(nn.Module):
|
||||
self.embed_pattern_ids = tokens_to_ids(
|
||||
["▁<", "|", "embed", "|", ">", "<0x0A>"])
|
||||
|
||||
self.head = PoolerHead(normalize=True, softmax=False)
|
||||
|
||||
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
|
||||
"""
|
||||
Find the first occurrence of target in arr starting from start_idx.
|
||||
@@ -75,7 +78,7 @@ class GritLMPooler(nn.Module):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def _get_instruction_len(self, prompt_token_ids: array) -> bool:
|
||||
def _get_instruction_len(self, prompt_token_ids: array) -> int:
|
||||
"""
|
||||
Get the length of the instruction in the prompt.
|
||||
|
||||
@@ -168,10 +171,10 @@ class GritLMPooler(nn.Module):
|
||||
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
|
||||
1)
|
||||
|
||||
pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
|
||||
pooled_data = self.head(mean_embeddings)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
|
||||
PoolingSequenceGroupOutput(data) for data in pooled_data
|
||||
]
|
||||
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
Reference in New Issue
Block a user