[Frontend] Separate pooling APIs in offline inference (#11129)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-13 18:40:07 +08:00
committed by GitHub
parent f93bf2b189
commit eeec9e3390
21 changed files with 669 additions and 304 deletions

View File

@@ -2,19 +2,20 @@ from array import array
from typing import List, Optional, Union
import torch
from torch import nn
import torch.nn as nn
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
PoolerOutput)
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
logger = init_logger(__name__)
@@ -52,6 +53,8 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(normalize=True, softmax=False)
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""
Find the first occurrence of target in arr starting from start_idx.
@@ -75,7 +78,7 @@ class GritLMPooler(nn.Module):
return i
return -1
def _get_instruction_len(self, prompt_token_ids: array) -> bool:
def _get_instruction_len(self, prompt_token_ids: array) -> int:
"""
Get the length of the instruction in the prompt.
@@ -168,10 +171,10 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
pooled_data = self.head(mean_embeddings)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
PoolingSequenceGroupOutput(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)