[Bugfix] Fix SHM cache initialization (#26427)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,26 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import asyncio
 from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Iterable, Mapping
 from typing import Any, Optional, Union

-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs.data import PromptType, TokensPrompt
-from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
-from vllm.inputs.preprocess import InputPreprocessor
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
-from vllm.plugins.io_processors.interface import IOProcessor
+from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.plugins.io_processors import IOProcessor
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import Device, collect_from_async_generator, random_uuid
+from vllm.utils import Device
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.processor import Processor

 logger = init_logger(__name__)

@@ -28,6 +25,11 @@ logger = init_logger(__name__)
 class EngineClient(ABC):
     """Protocol class for Clients to Engine"""

+    vllm_config: VllmConfig
+    model_config: ModelConfig
+    processor: Processor
+    io_processor: Optional[IOProcessor]
+
     @property
     @abstractmethod
     def is_running(self) -> bool: ...
@@ -61,180 +63,6 @@ class EngineClient(ABC):
         """Generate outputs for a request."""
         ...

-    async def beam_search(
-        self,
-        prompt: PromptType,
-        request_id: str,
-        params: BeamSearchParams,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-        include_stop_str_in_output = params.include_stop_str_in_output
-
-        preprocessor = await self.get_input_preprocessor()
-        tokenizer = preprocessor.get_tokenizer()
-        eos_token_id = tokenizer.eos_token_id
-
-        if is_explicit_encoder_decoder_prompt(prompt):
-            raise NotImplementedError
-        else:
-            processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
-
-        if processed_inputs["type"] == "embeds":
-            raise NotImplementedError
-
-        # This is a workaround to fix multimodal beam search; this is a
-        # bandaid fix for 2 small problems:
-        # 1. Multi_modal_data on the processed_inputs currently resolves to
-        #    `None`.
-        # 2. preprocessing above expands the multimodal placeholders. However,
-        #    this happens again in generation, so the double expansion causes
-        #    a mismatch.
-        # TODO - would be ideal to handle this more gracefully.
-        if isinstance(prompt, str):
-            prompt_text = prompt
-            prompt_token_ids = []
-            multi_modal_data = None
-        else:
-            prompt_text = prompt.get("prompt")
-            prompt_token_ids = prompt.get("prompt_token_ids", [])
-            multi_modal_data = prompt.get("multi_modal_data")
-
-        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
-
-        tokenized_length = len(prompt_token_ids)
-
-        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(
-            logprobs=2 * beam_width,
-            max_tokens=1,
-            temperature=temperature,
-        )
-        all_beams = [
-            BeamSearchSequence(
-                tokens=prompt_token_ids,
-                cum_logprob=0,
-                logprobs=[],
-                multi_modal_data=multi_modal_data,
-                mm_processor_kwargs=mm_processor_kwargs,
-                lora_request=lora_request,
-            )
-        ]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch, lora_req_batch = zip(
-                *[
-                    (
-                        TokensPrompt(
-                            prompt_token_ids=beam.tokens,
-                            multi_modal_data=beam.multi_modal_data,
-                            mm_processor_kwargs=beam.mm_processor_kwargs,
-                        ),
-                        beam.lora_request,
-                    )
-                    for beam in all_beams
-                ]
-            )
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, (individual_prompt, lora_req) in enumerate(
-                zip(prompts_batch, lora_req_batch)
-            ):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(
-                            individual_prompt,
-                            beam_search_params,
-                            request_id_item,
-                            lora_request=lora_req,
-                        )
-                    )
-                )
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        if token_id == eos_token_id and not ignore_eos:
-                            completed.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id]
-                                    if include_stop_str_in_output
-                                    else current_beam.tokens,
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    finish_reason="stop",
-                                    stop_reason=eos_token_id,
-                                )
-                            )
-                        else:
-                            new_beams.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id],
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    lora_request=current_beam.lora_request,
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    multi_modal_data=current_beam.multi_modal_data,
-                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
-                                )
-                            )
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            if beam.tokens[-1] == eos_token_id and not ignore_eos:
-                # Skip the eos token in the text.
-                tokens = beam.tokens[tokenized_length:-1]
-            else:
-                tokens = beam.tokens[tokenized_length:]
-            beam.text = tokenizer.decode(tokens)
-
-        yield RequestOutput(
-            request_id=request_id,
-            prompt=prompt_text,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.logprobs,
-                    finish_reason=beam.finish_reason
-                    if beam.finish_reason is not None
-                    else "length",
-                    stop_reason=beam.stop_reason,
-                )
-                for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=None,
-        )
-
     @abstractmethod
     def encode(
         self,
@@ -259,29 +87,11 @@ class EngineClient(ABC):
         """
         ...

-    @abstractmethod
-    async def get_vllm_config(self) -> VllmConfig:
-        """Get the vllm configuration of the vLLM engine."""
-        ...
-
-    @abstractmethod
-    async def get_model_config(self) -> ModelConfig:
-        """Get the model configuration of the vLLM engine."""
-        ...
-
-    @abstractmethod
-    async def get_input_preprocessor(self) -> InputPreprocessor:
-        """Get the input processor of the vLLM engine."""
-        ...
-
     @abstractmethod
     async def get_tokenizer(self) -> AnyTokenizer:
         """Get the tokenizer"""
         ...

-    async def get_io_processor(self) -> IOProcessor:
-        raise NotImplementedError
-
     @abstractmethod
     async def is_tracing_enabled(self) -> bool: ...
