Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
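ruff format follows the Black style: a call or import that no longer fits on one line is broken to one element per line with a four-space hanging indent and a trailing comma, rather than yapf's alignment under the opening parenthesis, and ruff's isort-compatible lint rules take over import sorting, so the "# yapf: disable" / "# yapf: enable" guards around hand-formatted import blocks below can simply be dropped. A minimal before/after sketch of the transformation applied throughout this diff; the helper and its arguments are invented for illustration (not vLLM code), and the exact line-length settings are an assumption rather than something shown in this commit:

    def pool(prompt, params, **kwargs):
        """Stand-in helper used only to illustrate the formatting change."""
        return prompt, params, kwargs

    # yapf-era layout: continuation lines aligned under the opening parenthesis.
    out = pool("hello",
               {"normalize": True},
               lora_request=None,
               request_logger=None,
               log_error_stack=False)

    # ruff/Black layout: a call that does not fit on one line gets a 4-space
    # hanging indent, one argument per line, and a trailing comma; shorter
    # calls are collapsed onto a single line instead.
    out = pool(
        "hello",
        {"normalize": True},
        lora_request=None,
        request_logger=None,
        log_error_stack=False,
    )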
@@ -17,14 +17,20 @@ from vllm.config import VllmConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
-# yapf: disable
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
-                                              IOProcessorRequest,
-                                              IOProcessorResponse,
-                                              PoolingChatRequest,
-                                              PoolingCompletionRequest,
-                                              PoolingRequest, PoolingResponse,
-                                              PoolingResponseData, UsageInfo)
+from vllm.entrypoints.openai.protocol import (
+    ErrorResponse,
+    IOProcessorRequest,
+    IOProcessorResponse,
+    PoolingChatRequest,
+    PoolingCompletionRequest,
+    PoolingRequest,
+    PoolingResponse,
+    PoolingResponseData,
+    UsageInfo,
+)
-# yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -55,7 +61,6 @@ def _get_data(
 
 
 class OpenAIServingPooling(OpenAIServing):
-
     def __init__(
         self,
        engine_client: EngineClient,
@@ -68,11 +73,13 @@ class OpenAIServingPooling(OpenAIServing):
         trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
-        super().__init__(engine_client=engine_client,
-                         model_config=vllm_config.model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         log_error_stack=log_error_stack)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=vllm_config.model_config,
+            models=models,
+            request_logger=request_logger,
+            log_error_stack=log_error_stack,
+        )
 
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
@@ -110,12 +117,13 @@ class OpenAIServingPooling(OpenAIServing):
 
             if getattr(request, "dimensions", None) is not None:
                 return self.create_error_response(
-                    "dimensions is currently not supported")
+                    "dimensions is currently not supported"
+                )
 
-            truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
-                                             None)
+            truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
             truncate_prompt_tokens = _validate_truncation_size(
-                self.max_model_len, truncate_prompt_tokens)
+                self.max_model_len, truncate_prompt_tokens
+            )
 
             if is_io_processor_request:
                 if self.io_processor is None:
@@ -123,19 +131,20 @@ class OpenAIServingPooling(OpenAIServing):
                         "No IOProcessor plugin installed. Please refer "
                         "to the documentation and to the "
                         "'prithvi_geospatial_mae_io_processor' "
-                        "offline inference example for more details.")
+                        "offline inference example for more details."
+                    )
 
                 validated_prompt = self.io_processor.parse_request(request)
 
                 engine_prompts = await self.io_processor.pre_process_async(
-                    prompt=validated_prompt, request_id=request_id)
+                    prompt=validated_prompt, request_id=request_id
+                )
 
             elif isinstance(request, PoolingChatRequest):
                 error_check_ret = self._validate_chat_template(
                     request_chat_template=request.chat_template,
                     chat_template_kwargs=request.chat_template_kwargs,
-                    trust_request_chat_template=self.
-                    trust_request_chat_template,
+                    trust_request_chat_template=self.trust_request_chat_template,
                 )
                 if error_check_ret is not None:
                     return error_check_ret
@@ -148,8 +157,7 @@ class OpenAIServingPooling(OpenAIServing):
                     tokenizer,
                     request.messages,
                     chat_template=request.chat_template or self.chat_template,
-                    chat_template_content_format=self.
-                    chat_template_content_format,
+                    chat_template_content_format=self.chat_template_content_format,
                     # In pooling requests, we are not generating tokens,
                     # so there is no need to append extra tokens to the input
                     add_generation_prompt=False,
@@ -162,8 +170,7 @@ class OpenAIServingPooling(OpenAIServing):
                     config=self._build_render_config(request),
                 )
             else:
-                raise ValueError(
-                    f"Unsupported request of type {type(request)}")
+                raise ValueError(f"Unsupported request of type {type(request)}")
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
@@ -181,13 +188,18 @@ class OpenAIServingPooling(OpenAIServing):
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
 
-                self._log_inputs(request_id_item,
-                                 engine_prompt,
-                                 params=pooling_params,
-                                 lora_request=lora_request)
+                self._log_inputs(
+                    request_id_item,
+                    engine_prompt,
+                    params=pooling_params,
+                    lora_request=lora_request,
+                )
 
-                trace_headers = (None if raw_request is None else await
-                                 self._get_trace_headers(raw_request.headers))
+                trace_headers = (
+                    None
+                    if raw_request is None
+                    else await self._get_trace_headers(raw_request.headers)
+                )
 
                 generator = self.engine_client.encode(
                     engine_prompt,
@@ -213,8 +225,7 @@ class OpenAIServingPooling(OpenAIServing):
             )
             return self.io_processor.output_to_response(output)
 
-        assert isinstance(request,
-                          (PoolingCompletionRequest, PoolingChatRequest))
+        assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
         num_prompts = len(engine_prompts)
 
         # Non-streaming response
@@ -226,8 +237,7 @@ class OpenAIServingPooling(OpenAIServing):
 
         assert all(final_res is not None for final_res in final_res_batch)
 
-        final_res_batch_checked = cast(list[PoolingRequestOutput],
-                                       final_res_batch)
+        final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)
 
         response = self.request_output_to_pooling_response(
             final_res_batch_checked,
@@ -278,9 +288,9 @@ class OpenAIServingPooling(OpenAIServing):
             usage=usage,
         )
 
-    def _build_render_config(
-            self, request: PoolingCompletionRequest) -> RenderConfig:
+    def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
         return RenderConfig(
             max_length=self.max_model_len,
             truncate_prompt_tokens=request.truncate_prompt_tokens,
-            add_special_tokens=request.add_special_tokens)
+            add_special_tokens=request.add_special_tokens,
+        )
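A note on the two directions seen above: ruff format, like Black, joins a bracketed expression onto a single line whenever it fits within the configured line length (hence the re-joined cast(...), getattr(...), and assert isinstance(...) lines) and otherwise explodes it to one element per line with a trailing comma. A trailing "magic" comma already present in the source pins the exploded layout even when the expression would fit. A tiny illustration with made-up values, not code from this file:

    # No trailing comma and the call fits, so ruff format keeps (or joins) it on one line.
    names = sorted(["pooling", "embedding"])

    # A trailing "magic" comma inside the call keeps it expanded even though it would fit.
    names = sorted(
        ["pooling", "embedding"],
    )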