[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -54,6 +54,7 @@ class MockModelConfig:
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
skip_tokenizer_init = False
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -53,6 +53,7 @@ class MockModelConfig:
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
skip_tokenizer_init = False
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -52,6 +52,7 @@ class MockModelConfig:
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
@@ -529,6 +529,7 @@ class MockModelConfig:
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
0
tests/renderers/inputs/__init__.py
Normal file
0
tests/renderers/inputs/__init__.py
Normal file
41
tests/renderers/inputs/test_preprocess.py
Normal file
41
tests/renderers/inputs/test_preprocess.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.renderers.inputs.preprocess import prompt_to_seq
|
||||
|
||||
|
||||
def test_empty_input():
    """Check `prompt_to_seq` on an empty list and on lists of empty token lists."""
    cases = [
        ([], []),
        ([[]], [[]]),
        ([[], []], [[], []]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
|
||||
|
||||
def test_text_input():
    """Check that a bare string is wrapped in a list and string lists are kept."""
    cases = [
        ("foo", ["foo"]),
        (["foo"], ["foo"]),
        (["foo", "bar"], ["foo", "bar"]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
|
||||
|
||||
def test_token_input():
    """Check that a flat token list becomes one prompt and nested lists are kept."""
    cases = [
        ([1, 2], [[1, 2]]),
        ([[1, 2]], [[1, 2]]),
        ([[1, 2], [3, 4]], [[1, 2], [3, 4]]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
|
||||
|
||||
def test_text_token_input():
    """Check that lists mixing token lists and strings keep their order."""
    cases = [
        ([[1, 2], "foo"], [[1, 2], "foo"]),
        (["foo", [1, 2]], ["foo", [1, 2]]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
|
||||
|
||||
def test_bytes_input():
    """Check that a bare `bytes` value is wrapped and lists of bytes are kept."""
    cases = [
        (b"foo", [b"foo"]),
        ([b"foo"], [b"foo"]),
        ([b"foo", b"bar"], [b"foo", b"bar"]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
|
||||
|
||||
def test_dict_input():
    """Check that a single dict prompt is wrapped and lists of dicts are kept."""
    text_prompt = {"prompt": "foo"}
    cases = [
        ({"prompt": "foo"}, [text_prompt]),
        ([{"prompt": "foo"}], [text_prompt]),
        (
            [{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}],
            [text_prompt, {"prompt_token_ids": [1, 2]}],
        ),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -9,8 +10,11 @@ import pybase64
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import SingletonPrompt
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -33,6 +37,7 @@ class MockModelConfig:
|
||||
encoder_config: dict[str, Any] | None = None
|
||||
enable_prompt_embeds: bool = True
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,65 +85,34 @@ def _build_renderer(
|
||||
return renderer
|
||||
|
||||
|
||||
def _preprocess_prompt(
    mdoel_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Normalize the input with `prompt_to_seq` and parse each non-bytes entry.

    `bytes` entries are returned untouched; every other entry is passed
    through `parse_model_prompt` together with the given config. The result
    is a list in the order produced by `prompt_to_seq`.
    """
    # NOTE(review): "mdoel_config" looks like a typo for "model_config"; the
    # name is kept here because renaming a parameter changes the signature.
    parsed = []
    for item in prompt_to_seq(prompt_or_prompts):
        if not isinstance(item, bytes):
            item = parse_model_prompt(mdoel_config, item)
        parsed.append(item)
    return parsed
|
||||
|
||||
|
||||
class TestValidatePrompt:
|
||||
STRING_INPUTS = [
|
||||
"",
|
||||
"foo",
|
||||
"foo bar",
|
||||
"foo baz bar",
|
||||
"foo bar qux baz",
|
||||
]
|
||||
|
||||
TOKEN_INPUTS = [
|
||||
[-1],
|
||||
[1],
|
||||
[1, 2],
|
||||
[1, 3, 4],
|
||||
[1, 2, 4, 3],
|
||||
]
|
||||
|
||||
INPUTS_SLICES = [
|
||||
slice(None, None, -1),
|
||||
slice(None, None, 2),
|
||||
slice(None, None, -2),
|
||||
]
|
||||
|
||||
# Test that a nested mixed-type list of lists raises a TypeError.
|
||||
def test_empty_input(self):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
renderer.render_completions([])
|
||||
renderer.render_prompts(_preprocess_prompt(renderer.config, []))
|
||||
|
||||
def test_invalid_type(self):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
with pytest.raises(TypeError, match="string or an array of tokens"):
|
||||
renderer.render_completions([[1, 2], ["foo", "bar"]])
|
||||
|
||||
@pytest.mark.parametrize("string_input", STRING_INPUTS)
|
||||
def test_string_consistent(self, string_input: str):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
assert renderer.render_completions(string_input) == renderer.render_completions(
|
||||
[string_input]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
|
||||
def test_token_consistent(self, token_input: list[int]):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
assert renderer.render_completions(token_input) == renderer.render_completions(
|
||||
[token_input]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
|
||||
def test_string_slice(self, inputs_slice: slice):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
assert renderer.render_completions(self.STRING_INPUTS)[
|
||||
inputs_slice
|
||||
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
|
||||
with pytest.raises(TypeError, match="should be a list of integers"):
|
||||
renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
@@ -146,7 +120,7 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
tokens = [101, 7592, 2088]
|
||||
prompts = renderer.render_completions(tokens)
|
||||
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
@@ -159,7 +133,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
|
||||
prompts = renderer.render_completions(token_lists)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, token_lists)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
@@ -174,7 +150,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
text_input = "x" * 10
|
||||
prompts = renderer.render_completions(text_input)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, text_input)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
@@ -187,7 +165,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
text_list_input = ["x" * 10, "x" * 12, "x" * 14]
|
||||
prompts = renderer.render_completions(text_list_input)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, text_list_input)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
@@ -200,7 +180,9 @@ class TestRenderPrompt:
|
||||
def test_zero_truncation(self):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
prompts = renderer.render_completions("x" * 200)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, "x" * 200)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
|
||||
@@ -212,7 +194,9 @@ class TestRenderPrompt:
|
||||
def test_pos_truncation(self):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
prompts = renderer.render_completions("x" * 200)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, "x" * 200)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
|
||||
@@ -224,7 +208,9 @@ class TestRenderPrompt:
|
||||
def test_neg_truncation(self):
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
prompts = renderer.render_completions("x" * 200)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, "x" * 200)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
|
||||
@@ -237,7 +223,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
|
||||
|
||||
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
|
||||
prompts = renderer.render_completions(long_tokens)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, long_tokens)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
|
||||
@@ -251,7 +239,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig(), truncation_side="right")
|
||||
|
||||
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
|
||||
prompts = renderer.render_completions(long_tokens)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, long_tokens)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
|
||||
@@ -266,7 +256,9 @@ class TestRenderPrompt:
|
||||
|
||||
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
|
||||
long_tokens = "x" * 150
|
||||
prompts = renderer.render_completions(long_tokens)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, long_tokens)
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
@@ -285,7 +277,9 @@ class TestRenderPrompt:
|
||||
|
||||
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
|
||||
long_tokens = "x" * 150
|
||||
prompts = renderer.render_completions(long_tokens)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, long_tokens)
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
@@ -304,7 +298,9 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
|
||||
prompts = renderer.render_completions(long_tokens)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, long_tokens)
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
@@ -318,7 +314,9 @@ class TestRenderPrompt:
|
||||
def test_no_tokenizer_for_text(self):
|
||||
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
|
||||
|
||||
prompts = renderer.render_completions("Hello world")
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, "Hello world")
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
|
||||
renderer.tokenize_prompts(
|
||||
@@ -330,7 +328,7 @@ class TestRenderPrompt:
|
||||
renderer = _build_renderer(MockModelConfig())
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
prompts = renderer.render_completions(tokens)
|
||||
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
|
||||
tensor_input = torch.randn(10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(tensor_input)
|
||||
|
||||
prompts = renderer.render_completions(prompt_embeds=embed_bytes)
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(renderer.config, embed_bytes)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
|
||||
torch.randn(12, 512, dtype=torch.float32),
|
||||
]
|
||||
|
||||
prompts = renderer.render_completions(
|
||||
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(
|
||||
renderer.config,
|
||||
[self._create_test_embed_bytes(t) for t in tensor_inputs],
|
||||
)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
|
||||
# Create tensor with more tokens than truncation limit
|
||||
tensor_input = torch.randn(20, 768, dtype=torch.float32)
|
||||
|
||||
prompts = renderer.render_completions(
|
||||
prompt_embeds=self._create_test_embed_bytes(tensor_input),
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(
|
||||
renderer.config, self._create_test_embed_bytes(tensor_input)
|
||||
)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
|
||||
for dtype in dtypes:
|
||||
tensor_input = torch.randn(5, 256, dtype=dtype)
|
||||
|
||||
prompts = renderer.render_completions(
|
||||
prompt_embeds=self._create_test_embed_bytes(tensor_input),
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(
|
||||
renderer.config, self._create_test_embed_bytes(tensor_input)
|
||||
)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
|
||||
# Test tensor with batch dimension gets squeezed
|
||||
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
|
||||
|
||||
prompts = renderer.render_completions(
|
||||
prompt_embeds=self._create_test_embed_bytes(tensor_input),
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(
|
||||
renderer.config, self._create_test_embed_bytes(tensor_input)
|
||||
)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
|
||||
text_input = "Hello world"
|
||||
tensor_input = torch.randn(5, 256, dtype=torch.float32)
|
||||
|
||||
prompts = renderer.render_completions(
|
||||
text_input,
|
||||
prompt_embeds=self._create_test_embed_bytes(tensor_input),
|
||||
prompts = renderer.render_prompts(
|
||||
_preprocess_prompt(
|
||||
renderer.config,
|
||||
[text_input, self._create_test_embed_bytes(tensor_input)],
|
||||
)
|
||||
)
|
||||
results = renderer.tokenize_prompts(
|
||||
prompts,
|
||||
@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
# First should be embed prompt
|
||||
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
|
||||
# Second should be tokens prompt
|
||||
assert "prompt_token_ids" in results[1]
|
||||
assert len(results[1]["prompt_token_ids"]) == len(text_input)
|
||||
# First should be tokens prompt
|
||||
assert "prompt_token_ids" in results[0]
|
||||
assert len(results[0]["prompt_token_ids"]) == len(text_input)
|
||||
# Second should be embed prompt
|
||||
assert torch.equal(results[1]["prompt_embeds"], tensor_input)
|
||||
|
||||
@@ -3,16 +3,40 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.renderers import ChatParams
|
||||
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
|
||||
@dataclass
class MockHFConfig:
    """Stub assigned to `MockModelConfig.hf_config`; carries only `model_type`."""

    # Placeholder value; nothing in the visible code reads it.
    model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
class MockModelConfig:
    """Lightweight model-config stand-in handed to `MistralRenderer` in tests.

    Only the annotated attributes are dataclass fields (and therefore
    constructor keywords, e.g. `MockModelConfig(skip_tokenizer_init=True)`).
    The unannotated ones are plain class attributes shared by all instances.
    """

    # Unannotated: class attribute, not a dataclass field.
    runner_type = "generate"
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    max_model_len: int = 100
    # Unannotated: class attributes, not dataclass fields.
    tokenizer_revision = None
    tokenizer_mode = "mistral"
    # Single MockHFConfig instance created at class-definition time and
    # shared by every MockModelConfig instance.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mistral_tokenizer_does_not_block_event_loop():
|
||||
@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
|
||||
time.sleep(2)
|
||||
return expected_tokens
|
||||
|
||||
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
|
||||
mock_tokenizer = Mock(spec=MistralTokenizer)
|
||||
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
|
||||
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
|
||||
mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
|
||||
mock_renderer._tokenizer = mock_tokenizer
|
||||
|
||||
task = mock_renderer.render_messages_async([], ChatParams())
|
||||
|
||||
@@ -4,52 +4,13 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import zip_enc_dec_prompts
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_processor_kwargs,expected_mm_kwargs",
|
||||
[
|
||||
(None, [{}, {}]),
|
||||
({}, [{}, {}]),
|
||||
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
|
||||
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
|
||||
],
|
||||
)
|
||||
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
|
||||
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
|
||||
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
|
||||
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
|
||||
zipped_prompts = zip_enc_dec_prompts(
|
||||
encoder_prompts, decoder_prompts, mm_processor_kwargs
|
||||
)
|
||||
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
|
||||
for enc, dec, exp_kwargs, zipped in zip(
|
||||
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
|
||||
):
|
||||
assert isinstance(zipped, dict)
|
||||
assert len(zipped.keys()) == 3
|
||||
assert zipped["encoder_prompt"] == enc
|
||||
assert zipped["decoder_prompt"] == dec
|
||||
assert zipped["mm_processor_kwargs"] == exp_kwargs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
[
|
||||
"facebook/chameleon-7b",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"prompt",
|
||||
[
|
||||
"",
|
||||
{"prompt_token_ids": []},
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
|
||||
@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"Applying huggingface processor on text inputs results in "
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
@@ -53,7 +54,11 @@ class EngineClient(ABC):
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
*,
|
||||
@@ -70,7 +75,7 @@ class EngineClient(ABC):
|
||||
@abstractmethod
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
@@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
validate_score_input,
|
||||
)
|
||||
from vllm.entrypoints.utils import log_non_default_args
|
||||
from vllm.inputs import (
|
||||
from vllm.inputs.data import (
|
||||
DataPrompt,
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -76,6 +73,13 @@ from vllm.outputs import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
conversation_to_seq,
|
||||
extract_prompt_components,
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -93,9 +97,6 @@ logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
|
||||
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
@@ -445,21 +446,20 @@ class LLM:
|
||||
if sampling_params is None:
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
self._validate_and_add_requests(
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts,
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def _get_modality_specific_lora_reqs(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
):
|
||||
# Grab the lora config off the vllm config on the engine,
|
||||
@@ -475,9 +475,6 @@ class LLM:
|
||||
):
|
||||
return lora_request
|
||||
|
||||
if not isinstance(prompts, Sequence) or isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
optional_loras = (
|
||||
[lora_request] * len(prompts)
|
||||
if not isinstance(lora_request, Sequence)
|
||||
@@ -495,14 +492,12 @@ class LLM:
|
||||
|
||||
def _resolve_single_prompt_mm_lora(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: DictPrompt | TokPrompt,
|
||||
lora_request: LoRARequest | None,
|
||||
default_mm_loras: dict[str, str] | None,
|
||||
):
|
||||
if (
|
||||
not default_mm_loras
|
||||
or not isinstance(prompt, dict)
|
||||
or not (mm_data := prompt.get("multi_modal_data") or {})
|
||||
if not default_mm_loras or not (
|
||||
mm_data := prompt.get("multi_modal_data") or {}
|
||||
):
|
||||
return lora_request
|
||||
|
||||
@@ -806,61 +801,11 @@ class LLM:
|
||||
add_special_tokens=not model_config.is_encoder_decoder,
|
||||
).with_kwargs(tokenization_kwargs)
|
||||
|
||||
def _normalize_prompts(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
) -> list[EnginePrompt | EngineEncDecPrompt]:
|
||||
if isinstance(prompts, str):
|
||||
prompts = TextPrompt(prompt=prompts)
|
||||
|
||||
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
|
||||
|
||||
def _preprocess_cmpl_singleton(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
tok_params: TokenizeParams,
|
||||
*,
|
||||
tokenize: bool,
|
||||
) -> EnginePrompt:
|
||||
renderer = self.llm_engine.renderer
|
||||
|
||||
if not isinstance(prompt, dict):
|
||||
prompt = renderer.render_completion(prompt)
|
||||
|
||||
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
|
||||
|
||||
def _preprocess_cmpl_enc_dec(
|
||||
self,
|
||||
prompt: ExplicitEncoderDecoderPrompt,
|
||||
tok_params: TokenizeParams,
|
||||
) -> EngineEncDecPrompt:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
return EngineEncDecPrompt(
|
||||
encoder_prompt=self._preprocess_cmpl_singleton(
|
||||
enc_prompt,
|
||||
tok_params,
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
),
|
||||
decoder_prompt=(
|
||||
None
|
||||
if dec_prompt is None
|
||||
else self._preprocess_cmpl_singleton(
|
||||
dec_prompt,
|
||||
tok_params,
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _preprocess_completion(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
prompts: Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[EnginePrompt | EngineEncDecPrompt]:
|
||||
) -> list[DictPrompt | TokPrompt]:
|
||||
"""
|
||||
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
|
||||
a format that can be passed to `_add_request`.
|
||||
@@ -871,32 +816,26 @@ class LLM:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
renderer = self.llm_engine.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
|
||||
|
||||
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
|
||||
for prompt in self._normalize_prompts(prompts):
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
|
||||
else:
|
||||
# Some MM models have non-default `add_special_tokens`
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
engine_prompts.append(
|
||||
self._preprocess_cmpl_singleton(
|
||||
prompt,
|
||||
tok_params,
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
)
|
||||
)
|
||||
engine_prompts = list[DictPrompt | TokPrompt]()
|
||||
for prompt in prompts:
|
||||
parsed_prompt = parse_model_prompt(model_config, prompt)
|
||||
in_prompt = renderer.render_prompt(parsed_prompt)
|
||||
|
||||
# Some MM models have non-default `add_special_tokens`
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
engine_prompts.append(
|
||||
in_prompt
|
||||
if model_config.is_multimodal_model
|
||||
else renderer.tokenize_prompt(in_prompt, tok_params)
|
||||
)
|
||||
|
||||
return engine_prompts
|
||||
|
||||
def _normalize_conversations(
|
||||
self,
|
||||
conversations: list[ChatCompletionMessageParam]
|
||||
| list[list[ChatCompletionMessageParam]],
|
||||
) -> list[list[ChatCompletionMessageParam]]:
|
||||
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
|
||||
|
||||
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
|
||||
model_config = self.model_config
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
@@ -909,8 +848,7 @@ class LLM:
|
||||
|
||||
def _preprocess_chat(
|
||||
self,
|
||||
conversations: list[ChatCompletionMessageParam]
|
||||
| list[list[ChatCompletionMessageParam]],
|
||||
conversations: Sequence[list[ChatCompletionMessageParam]],
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
@@ -919,7 +857,7 @@ class LLM:
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[EnginePrompt]:
|
||||
) -> list[DictPrompt | TokPrompt]:
|
||||
"""
|
||||
Convert a list of conversations into prompts so that they can then
|
||||
be used as input for other LLM APIs.
|
||||
@@ -947,11 +885,14 @@ class LLM:
|
||||
)
|
||||
tok_params = self._get_chat_tok_params(tokenization_kwargs)
|
||||
|
||||
engine_prompts = list[EnginePrompt]()
|
||||
for conversation in self._normalize_conversations(conversations):
|
||||
engine_prompts = list[DictPrompt | TokPrompt]()
|
||||
for conversation in conversations:
|
||||
_, in_prompt = renderer.render_messages(conversation, chat_params)
|
||||
if mm_processor_kwargs is not None:
|
||||
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
|
||||
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
|
||||
"encoder_prompt", in_prompt
|
||||
)
|
||||
target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore
|
||||
|
||||
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
|
||||
|
||||
@@ -960,8 +901,8 @@ class LLM:
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
| list[list[ChatCompletionMessageParam]],
|
||||
sampling_params: SamplingParams | list[SamplingParams] | None = None,
|
||||
| Sequence[list[ChatCompletionMessageParam]],
|
||||
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: LoRARequest | None = None,
|
||||
chat_template: str | None = None,
|
||||
@@ -984,7 +925,7 @@ class LLM:
|
||||
to the OpenAI API.
|
||||
|
||||
Args:
|
||||
messages: A list of conversations or a single conversation.
|
||||
messages: A sequence of conversations or a single conversation.
|
||||
|
||||
- Each conversation is represented as a list of messages.
|
||||
- Each message is a dictionary with 'role' and 'content' keys.
|
||||
@@ -1023,8 +964,23 @@ class LLM:
|
||||
A list of `RequestOutput` objects containing the generated
|
||||
responses in the same order as the input messages.
|
||||
"""
|
||||
prompts = self._preprocess_chat(
|
||||
messages,
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
raise ValueError(
|
||||
"LLM.chat() is only supported for generative models. "
|
||||
"Try passing `--runner generate` to use the model as a "
|
||||
"generative model."
|
||||
)
|
||||
|
||||
if sampling_params is None:
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
outputs = self._run_chat(
|
||||
messages=messages,
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
@@ -1035,13 +991,7 @@ class LLM:
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
return self.generate(
|
||||
prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
@@ -1163,7 +1113,7 @@ class LLM:
|
||||
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
|
||||
raise ValueError(msg)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts,
|
||||
params=pooling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
@@ -1171,8 +1121,6 @@ class LLM:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
|
||||
model_outputs = self.engine_class.validate_outputs(
|
||||
outputs, PoolingRequestOutput
|
||||
)
|
||||
@@ -1523,14 +1471,13 @@ class LLM:
|
||||
|
||||
prompts.append(engine_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts,
|
||||
params=pooling_params_list,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
@@ -1727,33 +1674,29 @@ class LLM:
|
||||
"""
|
||||
return self.llm_engine.get_metrics()
|
||||
|
||||
def _validate_and_add_requests(
|
||||
def _params_to_seq(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
params: SamplingParams
|
||||
| Sequence[SamplingParams]
|
||||
| PoolingParams
|
||||
| Sequence[PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: list[int] | None = None,
|
||||
) -> None:
|
||||
in_prompts = self._normalize_prompts(prompts)
|
||||
num_requests = len(in_prompts)
|
||||
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
num_requests: int,
|
||||
) -> Sequence[SamplingParams | PoolingParams]:
|
||||
if isinstance(params, Sequence):
|
||||
if len(params) != num_requests:
|
||||
raise ValueError(
|
||||
f"The lengths of prompts ({params}) "
|
||||
f"and lora_request ({len(params)}) must be the same."
|
||||
f"and params ({len(params)}) must be the same."
|
||||
)
|
||||
|
||||
engine_params = params
|
||||
else:
|
||||
engine_params = [params] * num_requests
|
||||
return params
|
||||
|
||||
return [params] * num_requests
|
||||
|
||||
def _lora_request_to_seq(
|
||||
self,
|
||||
lora_request: LoRARequest | None | Sequence[LoRARequest | None],
|
||||
num_requests: int,
|
||||
) -> Sequence[LoRARequest | None]:
|
||||
if isinstance(lora_request, Sequence):
|
||||
if len(lora_request) != num_requests:
|
||||
raise ValueError(
|
||||
@@ -1761,28 +1704,50 @@ class LLM:
|
||||
f"and lora_request ({len(lora_request)}) must be the same."
|
||||
)
|
||||
|
||||
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
|
||||
else:
|
||||
engine_lora_requests = [lora_request] * num_requests
|
||||
return lora_request
|
||||
|
||||
return [lora_request] * num_requests
|
||||
|
||||
def _priority_to_seq(
|
||||
self,
|
||||
priority: list[int] | None,
|
||||
num_requests: int,
|
||||
) -> Sequence[int]:
|
||||
if priority is not None:
|
||||
if len(priority) != num_requests:
|
||||
raise ValueError(
|
||||
f"The lengths of prompts ({num_requests}) "
|
||||
f"and priority ({len(priority)}) must be the same."
|
||||
)
|
||||
else:
|
||||
priority = [0] * num_requests
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in engine_params):
|
||||
return priority
|
||||
|
||||
return [0] * num_requests
|
||||
|
||||
def _run_completion(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
params: SamplingParams
|
||||
| PoolingParams
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(params, len(seq_prompts))
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in seq_params):
|
||||
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
|
||||
# Then, move the code from the `else` block to the top and let
|
||||
# `self._preprocess_completion` handle prompt normalization
|
||||
engine_prompts = [
|
||||
engine_prompt
|
||||
for in_prompt, param in zip(in_prompts, engine_params)
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_completion(
|
||||
[in_prompt],
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
@@ -1791,17 +1756,90 @@ class LLM:
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_completion(
|
||||
in_prompts,
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
for sp in engine_params:
|
||||
self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
params=seq_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
return self._run_engine(use_tqdm=use_tqdm)
|
||||
|
||||
def _run_chat(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
| Sequence[list[ChatCompletionMessageParam]],
|
||||
params: SamplingParams
|
||||
| PoolingParams
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: LoRARequest | None = None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
engine_prompts = self._preprocess_chat(
|
||||
conversation_to_seq(messages),
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
params=params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return self._run_engine(use_tqdm=use_tqdm)
|
||||
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
params: SamplingParams
|
||||
| PoolingParams
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: list[int] | None = None,
|
||||
) -> None:
|
||||
num_requests = len(prompts)
|
||||
seq_params = self._params_to_seq(params, num_requests)
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
|
||||
seq_priority = self._priority_to_seq(priority, num_requests)
|
||||
|
||||
for sp in seq_params:
|
||||
if isinstance(sp, SamplingParams):
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
it = engine_prompts
|
||||
it = prompts
|
||||
if use_tqdm:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
@@ -1812,10 +1850,10 @@ class LLM:
|
||||
for i, prompt in enumerate(it):
|
||||
request_id = self._add_request(
|
||||
prompt,
|
||||
engine_params[i],
|
||||
lora_request=engine_lora_requests[i],
|
||||
seq_params[i],
|
||||
lora_request=seq_lora_requests[i],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority[i],
|
||||
priority=seq_priority[i],
|
||||
)
|
||||
added_request_ids.append(request_id)
|
||||
except Exception as e:
|
||||
@@ -1825,13 +1863,13 @@ class LLM:
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
params: SamplingParams | PoolingParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: int = 0,
|
||||
) -> str:
|
||||
prompt_text, _, _ = get_prompt_components(prompt)
|
||||
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
|
||||
request_id = str(next(self.request_counter))
|
||||
|
||||
if params.truncate_prompt_tokens is not None:
|
||||
|
||||
@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import (
|
||||
@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
async def render_chat_request(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> (
|
||||
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
|
||||
| ErrorResponse
|
||||
):
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
|
||||
"""
|
||||
render chat request by validating and preprocessing inputs.
|
||||
|
||||
@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text = engine_prompt.get("prompt")
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
# If we are creating sub requests for multiple prompts, ensure that they
|
||||
# have unique request ids.
|
||||
@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
prompt=engine_prompt,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
self.max_model_len,
|
||||
request,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params: SamplingParams | BeamSearchParams
|
||||
|
||||
@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
async def render_completion_request(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse:
|
||||
) -> list[TokPrompt] | ErrorResponse:
|
||||
"""
|
||||
render completion request by validating and preprocessing inputs.
|
||||
|
||||
@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text = engine_prompt.get("prompt")
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
prompt=engine_prompt,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
self.max_model_len,
|
||||
request,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params: SamplingParams | BeamSearchParams
|
||||
@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
engine_prompt = engine_prompts[i]
|
||||
final_res.prompt = engine_prompt.get("prompt")
|
||||
final_res.prompt = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
|
||||
|
||||
@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[TokensPrompt | EmbedsPrompt],
|
||||
engine_prompts: list[TokPrompt],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
prompt_text = res.prompt
|
||||
if prompt_text is None:
|
||||
engine_prompt = engine_prompts[prompt_idx]
|
||||
prompt_text = engine_prompt.get("prompt")
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
|
||||
@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import (
|
||||
get_prompt_components,
|
||||
is_explicit_encoder_decoder_prompt,
|
||||
)
|
||||
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
SingletonDictPrompt,
|
||||
extract_prompt_components,
|
||||
extract_prompt_len,
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
|
||||
engine_prompts: list[TokPrompt] | None = None
|
||||
|
||||
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
|
||||
None
|
||||
@@ -247,7 +251,7 @@ class OpenAIServing:
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: TokPrompt,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -271,20 +275,12 @@ class OpenAIServing:
|
||||
|
||||
eos_token_id: int = tokenizer.eos_token_id # type: ignore
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
raise NotImplementedError("Encoder-decoder prompt not supported")
|
||||
|
||||
prompt_text: str | None
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalDataDict | None
|
||||
if isinstance(prompt, str):
|
||||
prompt_text = prompt
|
||||
prompt_token_ids = []
|
||||
multi_modal_data = None
|
||||
else:
|
||||
prompt_text = prompt.get("prompt") # type: ignore
|
||||
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
|
||||
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
|
||||
prompt_text: str | None = prompt.get("prompt") # type: ignore
|
||||
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
|
||||
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = None
|
||||
|
||||
@@ -963,22 +959,40 @@ class OpenAIServing:
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
) -> list[TokPrompt]:
|
||||
renderer = self.renderer
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
model_config = self.model_config
|
||||
|
||||
in_prompts = await renderer.render_completions_async(
|
||||
prompt_input, prompt_embeds
|
||||
)
|
||||
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
|
||||
tok_params = request.build_tok_params(model_config)
|
||||
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
if prompt_input is not None:
|
||||
prompts.extend(prompt_to_seq(prompt_input))
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
in_prompts = await renderer.render_prompts_async(parsed_prompts)
|
||||
|
||||
extra_items = {
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
}
|
||||
for prompt in engine_prompts:
|
||||
prompt.update(extra_items) # type: ignore
|
||||
for in_prompt in in_prompts:
|
||||
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
|
||||
"encoder_prompt", in_prompt
|
||||
)
|
||||
target_prompt.update(extra_items) # type: ignore
|
||||
|
||||
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
|
||||
|
||||
return engine_prompts
|
||||
|
||||
@@ -991,7 +1005,7 @@ class OpenAIServing:
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
renderer = self.renderer
|
||||
@@ -1009,17 +1023,21 @@ class OpenAIServing:
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
|
||||
conversation, prompt = await renderer.render_messages_async(
|
||||
conversation, in_prompt = await renderer.render_messages_async(
|
||||
messages, chat_params
|
||||
)
|
||||
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
|
||||
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
|
||||
"encoder_prompt", in_prompt
|
||||
)
|
||||
|
||||
extra_items = {
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
}
|
||||
engine_prompt.update(extra_items) # type: ignore
|
||||
target_prompt.update(extra_items) # type: ignore
|
||||
|
||||
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
@@ -1040,6 +1058,15 @@ class OpenAIServing:
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
def _extract_prompt_components(self, prompt: object):
|
||||
return extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
def _extract_prompt_text(self, prompt: object):
|
||||
return self._extract_prompt_components(prompt).text
|
||||
|
||||
def _extract_prompt_len(self, prompt: object):
|
||||
return extract_prompt_len(self.model_config, prompt)
|
||||
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
@@ -1067,7 +1094,7 @@ class OpenAIServing:
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
engine_prompt: TokPrompt,
|
||||
sampling_params: SamplingParams,
|
||||
tok_params: TokenizeParams,
|
||||
context: ConversationContext,
|
||||
@@ -1075,7 +1102,7 @@ class OpenAIServing:
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
prompt_text = engine_prompt.get("prompt")
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
orig_priority = priority
|
||||
sub_request = 0
|
||||
@@ -1145,12 +1172,12 @@ class OpenAIServing:
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
engine_prompt = engine_prompts[0]
|
||||
prompt_text = engine_prompt.get("prompt")
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
context.request,
|
||||
engine_prompt,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params, # type: ignore
|
||||
)
|
||||
|
||||
@@ -1161,20 +1188,20 @@ class OpenAIServing:
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptType,
|
||||
inputs: PromptType | TokPrompt,
|
||||
params: SamplingParams | PoolingParams | BeamSearchParams | None,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
|
||||
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
|
||||
components = self._extract_prompt_components(inputs)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
prompt_embeds,
|
||||
components.text,
|
||||
components.token_ids,
|
||||
components.embeds,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_len
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
def _validate_generator_input(
|
||||
self,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
engine_prompt: TokPrompt,
|
||||
) -> ErrorResponse | None:
|
||||
"""Add validations to the input to the generator here."""
|
||||
prompt_len = get_prompt_len(engine_prompt)
|
||||
prompt_len = self._extract_prompt_len(engine_prompt)
|
||||
if self.max_model_len <= prompt_len:
|
||||
error_message = (
|
||||
f"The engine prompt length {prompt_len} "
|
||||
@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
default_max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
request,
|
||||
engine_prompt,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
import zlib
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import cached_property
|
||||
from typing import Literal, TypeAlias, TypeVar, cast
|
||||
from typing import Final, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import SupportsTranscription, supports_transcription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
)
|
||||
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
self.task_type = task_type
|
||||
self.task_type: Final = task_type
|
||||
|
||||
self.asr_config = self.model_cls.get_speech_to_text_config(
|
||||
self.model_config, task_type
|
||||
@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
to_language=to_language,
|
||||
)
|
||||
if request.response_format == "verbose_json":
|
||||
if not is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise VLLMValidationError(
|
||||
"Expected prompt to be an encoder-decoder prompt",
|
||||
parameter="prompt",
|
||||
value=type(prompt).__name__,
|
||||
)
|
||||
|
||||
prompt = self._preprocess_verbose_prompt(prompt)
|
||||
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
|
||||
|
||||
prompts.append(prompt)
|
||||
|
||||
return prompts, duration
|
||||
|
||||
def _repl_verbose_text(self, text: str):
|
||||
return text.replace("<|notimestamps|>", "<|0.00|>")
|
||||
|
||||
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
|
||||
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
if isinstance(dec_prompt, str):
|
||||
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
|
||||
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
|
||||
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
|
||||
else:
|
||||
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
|
||||
raise VLLMValidationError(
|
||||
"Expected decoder_prompt to contain text",
|
||||
parameter="decoder_prompt",
|
||||
value=type(dec_prompt).__name__,
|
||||
)
|
||||
|
||||
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
|
||||
"<|notimestamps|>", "<|0.00|>"
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def _get_verbose_segments(
|
||||
|
||||
@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_output_base64,
|
||||
encode_pooling_output_float,
|
||||
)
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import chunk_list
|
||||
from vllm.utils.serial_utils import EmbedDType, Endianness
|
||||
@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
engine_prompt: TokPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
prompt_index: int,
|
||||
|
||||
@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_output_base64,
|
||||
encode_pooling_output_float,
|
||||
)
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import prompt_to_seq
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
"dimensions is currently not supported"
|
||||
)
|
||||
|
||||
engine_prompts: Sequence[PromptType | TokPrompt]
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
raw_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id
|
||||
)
|
||||
if not isinstance(engine_prompts, Sequence) or isinstance(
|
||||
engine_prompts, (str, bytes, bytearray)
|
||||
):
|
||||
engine_prompts = [engine_prompts]
|
||||
|
||||
engine_prompts = prompt_to_seq(raw_prompts)
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
|
||||
@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_len
|
||||
from vllm.logger import current_formatter_type, init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
@@ -189,7 +187,7 @@ def cli_env_setup():
|
||||
def get_max_tokens(
|
||||
max_model_len: int,
|
||||
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
|
||||
prompt: TokensPrompt | EmbedsPrompt,
|
||||
input_length: int,
|
||||
default_sampling_params: dict,
|
||||
) -> int:
|
||||
# NOTE: Avoid isinstance() for better efficiency
|
||||
@@ -204,7 +202,6 @@ def get_max_tokens(
|
||||
# CompletionRequest (also a fallback for ChatCompletionRequest)
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
input_length = get_prompt_len(prompt)
|
||||
default_max_tokens = max_model_len - input_length
|
||||
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||
|
||||
|
||||
@@ -16,11 +16,8 @@ from .data import (
|
||||
TextPrompt,
|
||||
TokenInputs,
|
||||
TokensPrompt,
|
||||
build_explicit_enc_dec_prompt,
|
||||
embeds_inputs,
|
||||
to_enc_dec_tuple_list,
|
||||
token_inputs,
|
||||
zip_enc_dec_prompts,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -39,8 +36,5 @@ __all__ = [
|
||||
"EncoderDecoderInputs",
|
||||
"ProcessorInputs",
|
||||
"SingletonInputs",
|
||||
"build_explicit_enc_dec_prompt",
|
||||
"to_enc_dec_tuple_list",
|
||||
"zip_enc_dec_prompts",
|
||||
"StreamingInput",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
@@ -23,7 +22,13 @@ else:
|
||||
MultiModalUUIDDict = object
|
||||
|
||||
|
||||
class _CommonKeys(TypedDict):
|
||||
# Inputs to LLM API
|
||||
class _PromptOptions(TypedDict):
|
||||
"""
|
||||
Additional options available to all
|
||||
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].
|
||||
"""
|
||||
|
||||
multi_modal_data: NotRequired[MultiModalDataDict | None]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
@@ -53,14 +58,14 @@ class _CommonKeys(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class TextPrompt(_CommonKeys):
|
||||
class TextPrompt(_PromptOptions):
|
||||
"""Schema for a text prompt."""
|
||||
|
||||
prompt: str
|
||||
"""The input text to be tokenized before passing to the model."""
|
||||
|
||||
|
||||
class TokensPrompt(_CommonKeys):
|
||||
class TokensPrompt(_PromptOptions):
|
||||
"""Schema for a tokenized prompt."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
@@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys):
|
||||
"""A list of token type IDs to pass to the cross encoder model."""
|
||||
|
||||
|
||||
class EmbedsPrompt(_CommonKeys):
|
||||
class EmbedsPrompt(_PromptOptions):
|
||||
"""Schema for a prompt provided via token embeddings."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
@@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys):
|
||||
"""The prompt text corresponding to the token embeddings, if available."""
|
||||
|
||||
|
||||
class DataPrompt(_CommonKeys):
|
||||
"""Represents generic inputs handled by IO processor plugins."""
|
||||
DecoderOnlyPrompt: TypeAlias = (
|
||||
str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
|
||||
)
|
||||
"""
|
||||
Schema of a prompt for a decoder-only model:
|
||||
|
||||
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
|
||||
- A tokenized prompt (list of token IDs, or
|
||||
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
|
||||
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
|
||||
|
||||
For encoder-decoder models, passing a singleton prompt is shorthand for passing
|
||||
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
|
||||
"""
|
||||
|
||||
|
||||
EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
|
||||
"""
|
||||
Schema of a prompt for the encoder part of an encoder-decoder model:
|
||||
|
||||
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
|
||||
- A tokenized prompt (list of token IDs, or
|
||||
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
|
||||
"""
|
||||
|
||||
|
||||
DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
|
||||
"""
|
||||
Schema of a prompt for the decoder part of an encoder-decoder model:
|
||||
|
||||
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
|
||||
- A tokenized prompt (list of token IDs, or
|
||||
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
|
||||
|
||||
Note:
|
||||
Multi-modal inputs are not supported for decoder prompts.
|
||||
"""
|
||||
|
||||
|
||||
class ExplicitEncoderDecoderPrompt(TypedDict):
|
||||
"""
|
||||
Schema for a pair of encoder and decoder singleton prompts.
|
||||
|
||||
Note:
|
||||
This schema is not valid for decoder-only models.
|
||||
"""
|
||||
|
||||
encoder_prompt: EncoderPrompt
|
||||
"""The prompt for the encoder part of the model."""
|
||||
|
||||
decoder_prompt: DecoderPrompt | None
|
||||
"""
|
||||
The prompt for the decoder part of the model.
|
||||
|
||||
Passing `None` will cause the prompt to be inferred automatically.
|
||||
"""
|
||||
|
||||
|
||||
EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
|
||||
"""
|
||||
Schema for a prompt for an encoder-decoder model.
|
||||
|
||||
You can pass a singleton encoder prompt, in which case the decoder prompt is
|
||||
considered to be `None` (i.e., infer automatically).
|
||||
"""
|
||||
|
||||
|
||||
SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
|
||||
"""
|
||||
Schema for a single prompt. This is as opposed to a data structure
|
||||
which encapsulates multiple prompts, such as
|
||||
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
|
||||
"""
|
||||
|
||||
|
||||
PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
|
||||
"""
|
||||
Schema for any prompt, regardless of model type.
|
||||
|
||||
This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
|
||||
"""
|
||||
|
||||
|
||||
class DataPrompt(_PromptOptions):
|
||||
"""
|
||||
Represents generic inputs that are converted to
|
||||
[`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.
|
||||
"""
|
||||
|
||||
data: Any
|
||||
"""The input data"""
|
||||
"""The input data."""
|
||||
|
||||
data_format: str
|
||||
"""The input data format"""
|
||||
"""The input data format."""
|
||||
|
||||
|
||||
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
|
||||
"""
|
||||
Set of possible schemas for a single prompt:
|
||||
|
||||
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
|
||||
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
|
||||
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
|
||||
|
||||
Note that "singleton" is as opposed to a data structure
|
||||
which encapsulates multiple prompts, i.e. of the sort
|
||||
which may be utilized for encoder/decoder models when
|
||||
the user desires to express both the encoder & decoder
|
||||
prompts explicitly, i.e.
|
||||
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
|
||||
|
||||
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
|
||||
employed as (1) input to a decoder-only model, (2) input to
|
||||
the encoder of an encoder/decoder model, in the scenario
|
||||
where the decoder-prompt is not specified explicitly, or
|
||||
(3) as a member of a larger data structure encapsulating
|
||||
more than one prompt, i.e.
|
||||
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
|
||||
"""
|
||||
|
||||
|
||||
_T1_co = TypeVar(
|
||||
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
|
||||
)
|
||||
_T2_co = TypeVar(
|
||||
"_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
|
||||
)
|
||||
|
||||
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
||||
# Outputs of processor
|
||||
class _InputOptions(TypedDict):
|
||||
"""
|
||||
Represents an encoder/decoder model input prompt,
|
||||
comprising an explicit encoder prompt and a decoder prompt.
|
||||
|
||||
The encoder and decoder prompts, respectively, may be formatted
|
||||
according to any of the
|
||||
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
|
||||
and are not required to have the same schema.
|
||||
|
||||
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
|
||||
should be at the top-level, and should not be set in the encoder/decoder
|
||||
prompts, since they are agnostic to the encoder/decoder.
|
||||
|
||||
Note that an
|
||||
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
|
||||
may not be used as an input to a decoder-only model,
|
||||
and that the `encoder_prompt` and `decoder_prompt`
|
||||
fields of this data structure themselves must be
|
||||
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
|
||||
Additional options available to all input types.
|
||||
"""
|
||||
|
||||
encoder_prompt: _T1_co
|
||||
|
||||
decoder_prompt: _T2_co | None
|
||||
|
||||
mm_processor_kwargs: NotRequired[dict[str, Any]]
|
||||
cache_salt: NotRequired[str]
|
||||
"""Optional cache salt to be used for prefix caching."""
|
||||
|
||||
|
||||
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
|
||||
"""
|
||||
Set of possible schemas for an LLM input, including
|
||||
both decoder-only and encoder/decoder input types:
|
||||
|
||||
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
|
||||
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
|
||||
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
|
||||
- A single data structure containing both an encoder and a decoder prompt
|
||||
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
|
||||
"""
|
||||
|
||||
|
||||
class TokenInputs(TypedDict):
|
||||
class TokenInputs(_InputOptions):
|
||||
"""Represents token-based inputs."""
|
||||
|
||||
type: Literal["token"]
|
||||
@@ -178,11 +203,6 @@ class TokenInputs(TypedDict):
|
||||
prompt_token_ids: list[int]
|
||||
"""The token IDs of the prompt."""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""
|
||||
Optional cache salt to be used for prefix caching.
|
||||
"""
|
||||
|
||||
|
||||
def token_inputs(
|
||||
prompt_token_ids: list[int],
|
||||
@@ -198,7 +218,7 @@ def token_inputs(
|
||||
return inputs
|
||||
|
||||
|
||||
class EmbedsInputs(TypedDict):
|
||||
class EmbedsInputs(_InputOptions):
|
||||
"""Represents embeddings-based inputs."""
|
||||
|
||||
type: Literal["embeds"]
|
||||
@@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict):
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""
|
||||
Optional cache salt to be used for prefix caching.
|
||||
"""
|
||||
|
||||
|
||||
def embeds_inputs(
|
||||
prompt_embeds: torch.Tensor,
|
||||
@@ -229,96 +244,60 @@ def embeds_inputs(
|
||||
|
||||
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
|
||||
"""
|
||||
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
|
||||
passed to the model executor.
|
||||
This specifies the data required for decoder-only models.
|
||||
A processed prompt from
|
||||
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
|
||||
which can be passed to
|
||||
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
|
||||
for decoder-only models.
|
||||
"""
|
||||
|
||||
|
||||
EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
|
||||
"""
|
||||
A processed encoder prompt from
|
||||
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
|
||||
which can be passed to
|
||||
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
|
||||
for encoder-decoder models.
|
||||
"""
|
||||
|
||||
|
||||
DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
|
||||
"""
|
||||
A processed decoder prompt from
|
||||
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
|
||||
which can be passed to
|
||||
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
|
||||
for encoder-decoder models.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderInputs(TypedDict):
|
||||
"""
|
||||
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
|
||||
are passed to the model executor.
|
||||
|
||||
This specifies the required data for encoder-decoder models.
|
||||
A processed pair of encoder and decoder singleton prompts from
|
||||
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
|
||||
which can be passed to
|
||||
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
|
||||
for encoder-decoder models.
|
||||
"""
|
||||
|
||||
encoder: TokenInputs | MultiModalEncDecInputs
|
||||
encoder: EncoderInputs
|
||||
"""The inputs for the encoder portion."""
|
||||
|
||||
decoder: TokenInputs | MultiModalInputs
|
||||
decoder: DecoderInputs
|
||||
"""The inputs for the decoder portion."""
|
||||
|
||||
|
||||
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
|
||||
"""
|
||||
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
|
||||
passed to [`Sequence`][collections.abc.Sequence].
|
||||
"""
|
||||
|
||||
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
|
||||
"""
|
||||
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
|
||||
A processed prompt from
|
||||
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
|
||||
which can be passed to
|
||||
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
|
||||
"""
|
||||
|
||||
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
|
||||
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
|
||||
|
||||
|
||||
def build_explicit_enc_dec_prompt(
|
||||
encoder_prompt: _T1,
|
||||
decoder_prompt: _T2 | None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
return ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=encoder_prompt,
|
||||
decoder_prompt=decoder_prompt,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def zip_enc_dec_prompts(
|
||||
enc_prompts: Iterable[_T1],
|
||||
dec_prompts: Iterable[_T2 | None],
|
||||
mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
|
||||
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
|
||||
"""
|
||||
Zip encoder and decoder prompts together into a list of
|
||||
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
|
||||
instances.
|
||||
|
||||
`mm_processor_kwargs` may also be provided; if a dict is passed, the same
|
||||
dictionary will be used for every encoder/decoder prompt. If an iterable is
|
||||
provided, it will be zipped with the encoder/decoder prompts.
|
||||
"""
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = cast(dict[str, Any], {})
|
||||
if isinstance(mm_processor_kwargs, dict):
|
||||
return [
|
||||
build_explicit_enc_dec_prompt(
|
||||
encoder_prompt,
|
||||
decoder_prompt,
|
||||
cast(dict[str, Any], mm_processor_kwargs),
|
||||
)
|
||||
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
|
||||
]
|
||||
return [
|
||||
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
|
||||
for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip(
|
||||
enc_prompts, dec_prompts, mm_processor_kwargs
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def to_enc_dec_tuple_list(
|
||||
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
|
||||
) -> list[tuple[_T1, _T2 | None]]:
|
||||
return [
|
||||
(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
|
||||
for enc_dec_prompt in enc_dec_prompts
|
||||
]
|
||||
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,88 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from .data import (
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
class ParsedStrPrompt(TypedDict):
|
||||
type: Literal["str"]
|
||||
content: str
|
||||
|
||||
|
||||
class ParsedTextPrompt(TypedDict):
|
||||
type: Literal["text"]
|
||||
content: TextPrompt
|
||||
|
||||
|
||||
class ParsedTokensPrompt(TypedDict):
|
||||
type: Literal["tokens"]
|
||||
content: TokensPrompt
|
||||
|
||||
|
||||
class ParsedEmbedsPrompt(TypedDict):
|
||||
type: Literal["embeds"]
|
||||
content: EmbedsPrompt
|
||||
|
||||
|
||||
ParsedSingletonPrompt: TypeAlias = (
|
||||
ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
|
||||
)
|
||||
|
||||
|
||||
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
elif isinstance(prompt, dict):
|
||||
# Type ignores are because mypy does not correctly infer the TypedDicts
|
||||
# Pyright does succeed.
|
||||
if "prompt_embeds" in prompt:
|
||||
return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt" in prompt:
|
||||
return ParsedTextPrompt(type="text", content=prompt)
|
||||
raise TypeError(
|
||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
|
||||
)
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType,
|
||||
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def split_enc_dec_prompt(
|
||||
prompt: PromptType,
|
||||
) -> tuple[SingletonPrompt, SingletonPrompt | None]:
|
||||
if isinstance(prompt, str):
|
||||
return prompt, None
|
||||
|
||||
if "encoder_prompt" in prompt and "decoder_prompt" in prompt:
|
||||
# NOTE: This passes pyright but not mypy
|
||||
return (
|
||||
prompt["encoder_prompt"], # type: ignore[typeddict-item]
|
||||
prompt["decoder_prompt"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return prompt, None
|
||||
from .data import ProcessorInputs, SingletonInputs
|
||||
|
||||
|
||||
def split_enc_dec_inputs(
|
||||
@@ -96,30 +15,3 @@ def split_enc_dec_inputs(
|
||||
)
|
||||
|
||||
return None, inputs
|
||||
|
||||
|
||||
class PromptComponents(NamedTuple):
|
||||
text: str | None = None
|
||||
token_ids: list[int] | None = None
|
||||
embeds: "torch.Tensor | None" = None
|
||||
|
||||
|
||||
def get_prompt_components(prompt: PromptType) -> PromptComponents:
|
||||
if isinstance(prompt, str):
|
||||
return PromptComponents(text=prompt)
|
||||
|
||||
if encoder_prompt := prompt.get("encoder_prompt"):
|
||||
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
|
||||
|
||||
return PromptComponents(
|
||||
text=prompt.get("prompt"), # type: ignore[arg-type]
|
||||
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
embeds=prompt.get("prompt_embeds"),
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
|
||||
return length_from_prompt_token_ids_or_embeds(
|
||||
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
prompt.get("prompt_embeds"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@@ -2,43 +2,51 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, overload
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig, ObservabilityConfig
|
||||
from vllm.inputs.parse import split_enc_dec_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalEncDecInputs,
|
||||
MultiModalInputs,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.renderers import renderer_from_config
|
||||
from vllm.renderers.inputs import (
|
||||
DecoderDictPrompt,
|
||||
DecoderOnlyDictPrompt,
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDictPrompt,
|
||||
SingletonDictPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.jsontree import json_iter_leaves
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
|
||||
from .data import (
|
||||
DecoderInputs,
|
||||
DecoderOnlyInputs,
|
||||
EmbedsInputs,
|
||||
EmbedsPrompt,
|
||||
EncoderDecoderInputs,
|
||||
EncoderInputs,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokenInputs,
|
||||
TokensPrompt,
|
||||
embeds_inputs,
|
||||
token_inputs,
|
||||
)
|
||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -328,9 +336,36 @@ class InputPreprocessor:
|
||||
|
||||
return inputs
|
||||
|
||||
@overload
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
prompt: EncoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> EncoderInputs: ...
|
||||
|
||||
@overload
|
||||
def _prompt_to_llm_inputs( # type: ignore[misc]
|
||||
self,
|
||||
prompt: DecoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> DecoderInputs: ...
|
||||
|
||||
@overload
|
||||
def _prompt_to_llm_inputs( # type: ignore[misc]
|
||||
self,
|
||||
prompt: DecoderOnlyDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> DecoderOnlyInputs: ...
|
||||
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
@@ -346,34 +381,25 @@ class InputPreprocessor:
|
||||
|
||||
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_embeds(parsed["content"])
|
||||
if parsed["type"] == "tokens":
|
||||
if "prompt_token_ids" in prompt:
|
||||
return self._process_tokens(
|
||||
parsed["content"],
|
||||
prompt, # type: ignore[arg-type]
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
if parsed["type"] == "text":
|
||||
|
||||
if "prompt" in prompt:
|
||||
return self._process_text(
|
||||
parsed["content"],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
if parsed["type"] == "str":
|
||||
return self._process_text(
|
||||
TextPrompt(prompt=parsed["content"]),
|
||||
prompt, # type: ignore[arg-type]
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
assert_never(prompt) # type: ignore[arg-type]
|
||||
|
||||
def _validate_enc_inputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
) -> TokenInputs | MultiModalEncDecInputs:
|
||||
def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
|
||||
if inputs["type"] == "embeds":
|
||||
raise ValueError(
|
||||
"Embedding inputs are not supported for encoder-decoder models"
|
||||
@@ -387,10 +413,7 @@ class InputPreprocessor:
|
||||
|
||||
return inputs # type: ignore[return-value]
|
||||
|
||||
def _validate_dec_inputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
) -> TokenInputs | MultiModalInputs:
|
||||
def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
|
||||
if inputs["type"] == "embeds":
|
||||
raise ValueError(
|
||||
"Embedding inputs are not supported for encoder-decoder models"
|
||||
@@ -403,14 +426,15 @@ class InputPreprocessor:
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: SingletonInputs | None = None,
|
||||
) -> EncoderDecoderInputs:
|
||||
if decoder_inputs is None:
|
||||
decoder_inputs = encoder_inputs
|
||||
|
||||
enc_inputs = self._validate_enc_inputs(encoder_inputs)
|
||||
dec_inputs = self._validate_dec_inputs(decoder_inputs)
|
||||
|
||||
enc_inputs_new: TokenInputs | MultiModalEncDecInputs
|
||||
dec_inputs_new: TokenInputs | MultiModalInputs
|
||||
if decoder_inputs is None:
|
||||
dec_inputs: DecoderInputs = enc_inputs # type: ignore[assignment]
|
||||
else:
|
||||
dec_inputs = self._validate_dec_inputs(decoder_inputs)
|
||||
|
||||
enc_inputs_new: EncoderInputs
|
||||
dec_inputs_new: DecoderInputs
|
||||
|
||||
if enc_inputs["type"] == "multimodal":
|
||||
enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
|
||||
@@ -437,7 +461,7 @@ class InputPreprocessor:
|
||||
|
||||
def _process_encoder_decoder_prompt(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
@@ -448,24 +472,6 @@ class InputPreprocessor:
|
||||
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
|
||||
instance.
|
||||
|
||||
There are two types of input prompts:
|
||||
singleton prompts which carry only the
|
||||
encoder prompt, and explicit encoder/decoder
|
||||
prompts which carry both the encoder and the
|
||||
decoder prompts as member variables.
|
||||
|
||||
This function handles the following scenarios:
|
||||
* Singleton encoder prompt: extract encoder prompt
|
||||
token ids & infer default decoder prompt token ids
|
||||
* Explicit encoder/decoder prompt: extract encoder
|
||||
and decoder prompt token ids
|
||||
|
||||
Note that for Explicit encoder/decoder prompts,
|
||||
each sub-prompt (encoder or decoder prompt) can
|
||||
have any possible singleton type; thus this
|
||||
method relies on helper functions to obtain
|
||||
token ids for the sub-prompts.
|
||||
|
||||
Arguments:
|
||||
|
||||
* prompt: an input prompt
|
||||
@@ -475,7 +481,8 @@ class InputPreprocessor:
|
||||
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
|
||||
instance
|
||||
"""
|
||||
encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt)
|
||||
encoder_prompt = prompt["encoder_prompt"]
|
||||
decoder_prompt = prompt["decoder_prompt"]
|
||||
|
||||
return self._build_enc_dec_inputs(
|
||||
encoder_inputs=self._prompt_to_llm_inputs(
|
||||
@@ -495,7 +502,7 @@ class InputPreprocessor:
|
||||
|
||||
def _process_decoder_only_prompt(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
prompt: DecoderOnlyDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
@@ -521,7 +528,7 @@ class InputPreprocessor:
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
@@ -530,25 +537,20 @@ class InputPreprocessor:
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder.
|
||||
return self._process_encoder_decoder_prompt(
|
||||
prompt,
|
||||
parse_enc_dec_prompt(prompt),
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise ValueError(
|
||||
"Cannot pass encoder-decoder prompt to decoder-only models"
|
||||
)
|
||||
|
||||
return self._process_decoder_only_prompt(
|
||||
prompt,
|
||||
parse_dec_only_prompt(prompt),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from typing_extensions import NotRequired, TypeVar
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
@@ -32,9 +32,13 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
import torch.types
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from vllm.inputs.data import _InputOptions
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
_InputOptions = dict
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
|
||||
@@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality.
|
||||
"""
|
||||
|
||||
|
||||
class MultiModalInputs(TypedDict):
|
||||
class MultiModalInputs(_InputOptions):
|
||||
"""
|
||||
Represents the outputs of
|
||||
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
|
||||
@@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict):
|
||||
`prompt_token_ids`.
|
||||
"""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""
|
||||
Optional cache salt to be used for prefix caching.
|
||||
"""
|
||||
|
||||
|
||||
class MultiModalEncDecInputs(MultiModalInputs):
|
||||
"""
|
||||
|
||||
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
@@ -565,7 +566,7 @@ class Platform:
|
||||
@classmethod
|
||||
def validate_request(
|
||||
cls,
|
||||
prompt: "PromptType",
|
||||
prompt: "PromptType | DictPrompt | TokPrompt",
|
||||
params: "SamplingParams | PoolingParams",
|
||||
processed_inputs: "ProcessorInputs",
|
||||
) -> None:
|
||||
|
||||
@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
@@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
@@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
@@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
|
||||
@@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
@@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer):
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer):
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
33
vllm/renderers/inputs/__init__.py
Normal file
33
vllm/renderers/inputs/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .preprocess import (
|
||||
DecoderDictPrompt,
|
||||
DecoderOnlyDictPrompt,
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDictPrompt,
|
||||
SingletonDictPrompt,
|
||||
)
|
||||
from .tokenize import (
|
||||
DecoderOnlyTokPrompt,
|
||||
DecoderTokPrompt,
|
||||
EncoderDecoderTokPrompt,
|
||||
EncoderTokPrompt,
|
||||
SingletonTokPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DecoderOnlyDictPrompt",
|
||||
"EncoderDictPrompt",
|
||||
"DecoderDictPrompt",
|
||||
"EncoderDecoderDictPrompt",
|
||||
"SingletonDictPrompt",
|
||||
"DictPrompt",
|
||||
"DecoderOnlyTokPrompt",
|
||||
"EncoderTokPrompt",
|
||||
"DecoderTokPrompt",
|
||||
"EncoderDecoderTokPrompt",
|
||||
"SingletonTokPrompt",
|
||||
"TokPrompt",
|
||||
]
|
||||
255
vllm/renderers/inputs/preprocess.py
Normal file
255
vllm/renderers/inputs/preprocess.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Schemas and utilities for preprocessing inputs.
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
|
||||
|
||||
from vllm.inputs import (
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
@overload
def prompt_to_seq(
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: ExplicitEncoderDecoderPrompt
    | Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...


def prompt_to_seq(
    prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
    """
    Normalize a single prompt or a collection of prompts into a sequence.

    A `dict`, `str` or `bytes` value, as well as a non-empty input accepted by
    `is_list_of(..., int)` (i.e. a single token prompt), is wrapped in a
    one-element list. Any other input (including an empty one) is assumed to
    already be a sequence of prompts and is returned unchanged.
    """
    # A lone dictionary / text / serialized prompt is always a single prompt.
    if isinstance(prompt_or_prompts, (dict, str, bytes)):
        return [prompt_or_prompts]  # type: ignore[list-item]

    # A non-empty list of integers is one token prompt, not a batch of prompts.
    # The emptiness check comes first so that `[]` is passed through as-is.
    if len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int):
        return [prompt_or_prompts]  # type: ignore[list-item]

    return prompt_or_prompts  # type: ignore[return-value]
|
||||
|
||||
|
||||
def conversation_to_seq(
    conversation_or_conversations: list["ChatCompletionMessageParam"]
    | Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
    """
    Normalize a single conversation or a collection of conversations.

    A non-empty input accepted by `is_list_of(..., dict)` is treated as one
    conversation (a list of message dictionaries) and wrapped in a one-element
    list. Any other input (including an empty one) is returned unchanged.
    """
    # Check emptiness first so that an empty input is passed through as-is.
    is_single_conversation = len(conversation_or_conversations) > 0 and is_list_of(
        conversation_or_conversations, dict
    )

    if is_single_conversation:
        return [conversation_or_conversations]  # type: ignore[list-item]

    return conversation_or_conversations  # type: ignore[return-value]
|
||||
|
||||
|
||||
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
|
||||
"""
|
||||
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
|
||||
"""
|
||||
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
|
||||
"""
|
||||
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderDictPrompt(TypedDict):
    """
    Dictionary form of an
    [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt],
    holding the encoder prompt and the (optional) decoder prompt separately.
    """

    # Prompt for the encoder, already normalized to a dictionary.
    encoder_prompt: EncoderDictPrompt

    # Prompt for the decoder, already normalized to a dictionary;
    # `None` when no decoder prompt was given.
    decoder_prompt: DecoderDictPrompt | None
|
||||
|
||||
|
||||
SingletonDictPrompt: TypeAlias = (
|
||||
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
|
||||
)
|
||||
"""
|
||||
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
|
||||
"""
|
||||
A [`PromptType`][vllm.inputs.data.PromptType]
|
||||
that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
    """
    Normalize a prompt for a decoder-only model into a dictionary.

    - A string becomes a `TextPrompt`.
    - A list becomes a `TokensPrompt` (it must pass `is_list_of(..., int)`).
    - A dictionary is returned as-is, provided that it has at least one of
      the `prompt`, `prompt_token_ids` or `prompt_embeds` keys and
      is not an encoder-decoder prompt.

    Raises:
        TypeError: If the prompt has an unsupported type or content.
    """
    if isinstance(prompt, dict):
        if "encoder_prompt" in prompt:
            raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")

        content_keys = ("prompt", "prompt_token_ids", "prompt_embeds")
        if not any(key in prompt for key in content_keys):
            raise TypeError(
                "Prompt dictionary must contain text, tokens, or embeddings"
            )

        # Already in dictionary form; the same object is handed back.
        return prompt  # type: ignore[return-value]

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
    """
    Normalize the encoder part of an encoder-decoder prompt into a dictionary.

    Strings and lists are converted to `TextPrompt` / `TokensPrompt`.
    A dictionary is returned as-is if it has a `prompt` or `prompt_token_ids`
    key; embedding prompts are rejected.

    Raises:
        TypeError: If the prompt has an unsupported type or content.
    """
    if isinstance(prompt, dict):
        if "prompt_embeds" in prompt:
            raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

        if not ("prompt" in prompt or "prompt_token_ids" in prompt):
            raise TypeError("Prompt dictionary must contain text or tokens")

        # Already in dictionary form; the same object is handed back.
        return prompt  # type: ignore[return-value]

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
    """
    Normalize the decoder part of an encoder-decoder prompt into a dictionary.

    Strings and lists are converted to `TextPrompt` / `TokensPrompt`.
    A dictionary is returned as-is if it has a `prompt` or `prompt_token_ids`
    key; embedding prompts and multi-modal fields are rejected.

    Raises:
        TypeError: If the prompt has an unsupported type or content.
    """
    if isinstance(prompt, dict):
        if "prompt_embeds" in prompt:
            raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

        # None of the multi-modal fields may appear on the decoder prompt.
        mm_keys = ("multi_modal_data", "mm_processor_kwargs", "multi_modal_uuids")
        if any(key in prompt for key in mm_keys):
            raise TypeError("Cannot pass multi-modal inputs to decoder prompt")

        if not ("prompt" in prompt or "prompt_token_ids" in prompt):
            raise TypeError("Prompt dictionary must contain text or tokens")

        # Already in dictionary form; the same object is handed back.
        return prompt  # type: ignore[return-value]

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)

        raise TypeError("Token prompt should be a list of integers")

    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
    """
    Parse a prompt for an encoder-decoder model and normalize it to a dictionary.

    A dictionary that contains an `encoder_prompt` key is treated as an
    explicit encoder-decoder prompt. Any other input is treated as the
    encoder prompt, with no decoder prompt.

    Raises:
        TypeError: If the encoder or decoder prompt fails validation.
    """
    if isinstance(prompt, dict) and "encoder_prompt" in prompt:
        enc_prompt: object = prompt["encoder_prompt"]  # type: ignore[typeddict-item]
        # Use `.get` so that an explicit prompt which omits `decoder_prompt`
        # is handled like `decoder_prompt=None` instead of surfacing a bare
        # `KeyError`, unlike every other malformed input (`TypeError`).
        dec_prompt: object | None = prompt.get("decoder_prompt")
    else:
        enc_prompt = prompt
        dec_prompt = None

    return EncoderDecoderDictPrompt(
        encoder_prompt=_parse_enc_prompt(enc_prompt),
        decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
    )
|
||||
|
||||
|
||||
def parse_model_prompt(model_config: "ModelConfig", prompt: object):
    """
    Parse a prompt with the parser that matches the model architecture.

    Uses `parse_enc_dec_prompt` when `model_config.is_encoder_decoder` is
    truthy, and `parse_dec_only_prompt` otherwise.
    """
    if not model_config.is_encoder_decoder:
        return parse_dec_only_prompt(prompt)

    return parse_enc_dec_prompt(prompt)
|
||||
|
||||
|
||||
class PromptComponents(NamedTuple):
|
||||
text: str | None = None
|
||||
token_ids: list[int] | None = None
|
||||
embeds: "torch.Tensor | None" = None
|
||||
|
||||
|
||||
def extract_prompt_components(
    model_config: "ModelConfig",
    prompt: object,
) -> PromptComponents:
    """
    Extract the text, token IDs and embeddings from a prompt.

    For encoder-decoder models the components are taken from the encoder
    prompt; otherwise they are taken from the decoder-only prompt.
    Components that are absent from the prompt are `None`.
    """
    is_enc_dec = model_config.is_encoder_decoder
    target_prompt = (
        parse_dec_only_prompt(prompt)
        if not is_enc_dec
        else parse_enc_dec_prompt(prompt)["encoder_prompt"]
    )

    return PromptComponents(
        target_prompt.get("prompt"),
        target_prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        target_prompt.get("prompt_embeds"),
    )
|
||||
|
||||
|
||||
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
    """
    Compute the length of a prompt from its token IDs or embeddings.

    For encoder-decoder models the encoder prompt is measured; otherwise the
    decoder-only prompt is measured. The result is whatever
    `length_from_prompt_token_ids_or_embeds` returns for the prompt's
    `prompt_token_ids` and `prompt_embeds` entries.
    """
    is_enc_dec = model_config.is_encoder_decoder
    target_prompt = (
        parse_dec_only_prompt(prompt)
        if not is_enc_dec
        else parse_enc_dec_prompt(prompt)["encoder_prompt"]
    )

    token_ids = target_prompt.get("prompt_token_ids")
    embeds = target_prompt.get("prompt_embeds")

    return length_from_prompt_token_ids_or_embeds(
        token_ids,  # type: ignore[arg-type]
        embeds,
    )
|
||||
57
vllm/renderers/inputs/tokenize.py
Normal file
57
vllm/renderers/inputs/tokenize.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Schemas and utilities for tokenization inputs.
|
||||
"""
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TypeAlias, TypedDict
|
||||
|
||||
from vllm.inputs import EmbedsPrompt, TokensPrompt
|
||||
|
||||
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
|
||||
"""
|
||||
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
EncoderTokPrompt: TypeAlias = TokensPrompt
|
||||
"""
|
||||
A [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
DecoderTokPrompt: TypeAlias = TokensPrompt
|
||||
"""
|
||||
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderTokPrompt(TypedDict):
    """
    Tokenized form of an
    [`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt],
    holding the encoder prompt and the (optional) decoder prompt separately.
    """

    # Tokenized prompt for the encoder.
    encoder_prompt: EncoderTokPrompt

    # Tokenized prompt for the decoder; `None` when no decoder prompt was given.
    decoder_prompt: DecoderTokPrompt | None
|
||||
|
||||
|
||||
SingletonTokPrompt: TypeAlias = (
|
||||
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
|
||||
)
|
||||
"""
|
||||
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
|
||||
|
||||
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
|
||||
"""
|
||||
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
|
||||
that has been tokenized.
|
||||
"""
|
||||
@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
@@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer):
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
prompt = parse_dec_only_prompt(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
@@ -2,14 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, overload
|
||||
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
from .embed_utils import safe_load_prompt_embeds
|
||||
from .inputs import (
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDecoderTokPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
from .params import ChatParams, TokenizeParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,140 +63,217 @@ class BaseRenderer(ABC):
|
||||
return self._async_tokenizer
|
||||
|
||||
# Step 1: Convert raw inputs to prompts
|
||||
def render_completion(
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_raw: str | list[int] | bytes,
|
||||
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
|
||||
error_msg = "Each prompt must be a string or an array of tokens"
|
||||
prompt: DictPrompt | bytes,
|
||||
) -> DictPrompt:
|
||||
if isinstance(prompt, bytes):
|
||||
embeds = safe_load_prompt_embeds(self.config, prompt)
|
||||
prompt = EmbedsPrompt(prompt_embeds=embeds)
|
||||
|
||||
if isinstance(prompt_raw, str):
|
||||
return TextPrompt(prompt=prompt_raw)
|
||||
return prompt
|
||||
|
||||
if isinstance(prompt_raw, list):
|
||||
if not is_list_of(prompt_raw, int):
|
||||
raise TypeError(error_msg)
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
|
||||
if isinstance(prompt_raw, bytes):
|
||||
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
|
||||
return EmbedsPrompt(prompt_embeds=embeds)
|
||||
|
||||
raise TypeError(error_msg)
|
||||
|
||||
def render_completions(
|
||||
def render_prompts(
|
||||
self,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
prompts_raw = list[str | list[int] | bytes]()
|
||||
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
if isinstance(prompt_embeds, bytes):
|
||||
prompts_raw.append(prompt_embeds)
|
||||
else:
|
||||
prompts_raw.extend(prompt_embeds)
|
||||
|
||||
if prompt_input is not None:
|
||||
if isinstance(prompt_input, str) or (
|
||||
len(prompt_input) > 0 and is_list_of(prompt_input, int)
|
||||
):
|
||||
prompts_raw.append(prompt_input) # type: ignore[arg-type]
|
||||
else:
|
||||
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
|
||||
|
||||
if len(prompts_raw) == 0:
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
) -> list[DictPrompt]:
|
||||
if len(prompts) == 0:
|
||||
raise ValueError("You must pass at least one prompt")
|
||||
|
||||
return [self.render_completion(prompt) for prompt in prompts_raw]
|
||||
return [self.render_prompt(prompt) for prompt in prompts]
|
||||
|
||||
async def render_completions_async(
|
||||
async def render_prompts_async(
|
||||
self,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
return self.render_completions(prompt_input, prompt_embeds)
|
||||
prompts: Sequence[DictPrompt | bytes],
|
||||
) -> list[DictPrompt]:
|
||||
return self.render_prompts(prompts)
|
||||
|
||||
@abstractmethod
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list["ConversationMessage"], DictPrompt]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list["ConversationMessage"], DictPrompt]:
|
||||
return self.render_messages(messages, params)
|
||||
|
||||
# Step 2: Tokenize prompts if necessary
|
||||
def _tokenize_prompt(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Encode the text of `prompt` into token IDs.

    The text under `prompt["prompt"]` is encoded with the renderer's tokenizer
    using `params.get_encode_kwargs()`. The returned `TokensPrompt` carries
    the new `prompt_token_ids` together with every field of the input prompt.
    """
    tokenizer = self.get_tokenizer()
    text = prompt["prompt"]
    encode_kwargs = params.get_encode_kwargs()
    token_ids = tokenizer.encode(text, **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
|
||||
|
||||
async def _tokenize_prompt_async(
    self,
    prompt: TextPrompt,
    params: TokenizeParams,
) -> TokensPrompt:
    """
    Encode the text of `prompt` into token IDs using the async tokenizer.

    The text under `prompt["prompt"]` is encoded with
    `params.get_encode_kwargs()`. The returned `TokensPrompt` carries the new
    `prompt_token_ids` together with every field of the input prompt.
    """
    async_tokenizer = self.get_async_tokenizer()
    text = prompt["prompt"]
    encode_kwargs = params.get_encode_kwargs()
    token_ids = await async_tokenizer.encode(text, **encode_kwargs)

    return TokensPrompt(prompt_token_ids=token_ids, **prompt)
|
||||
|
||||
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
    """
    Decode `prompt["prompt_token_ids"]` and store the text under `"prompt"`.

    The prompt is updated in place and the same object is returned.
    """
    token_ids = prompt["prompt_token_ids"]
    decoded_text = self.get_tokenizer().decode(token_ids)
    prompt["prompt"] = decoded_text

    return prompt
|
||||
|
||||
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
    """
    Decode `prompt["prompt_token_ids"]` with the async tokenizer and store
    the text under `"prompt"`.

    The prompt is updated in place and the same object is returned.
    """
    token_ids = prompt["prompt_token_ids"]
    decoded_text = await self.get_async_tokenizer().decode(token_ids)
    prompt["prompt"] = decoded_text

    return prompt
|
||||
|
||||
def _tokenize_enc_dec_prompt(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Tokenize both halves of an encoder-decoder prompt.

    The encoder prompt is always passed through `tokenize_prompt`; the decoder
    prompt is tokenized the same way unless it is `None`, in which case it
    stays `None`.
    """
    encoder_tok = self.tokenize_prompt(prompt["encoder_prompt"], params)

    decoder_raw = prompt["decoder_prompt"]
    if decoder_raw is None:
        decoder_tok = None
    else:
        decoder_tok = self.tokenize_prompt(decoder_raw, params)

    return EncoderDecoderTokPrompt(
        encoder_prompt=encoder_tok,
        decoder_prompt=decoder_tok,
    )
|
||||
|
||||
async def _tokenize_enc_dec_prompt_async(
    self,
    prompt: EncoderDecoderDictPrompt,
    params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
    """
    Tokenize both halves of an encoder-decoder prompt via `asyncio.gather`.

    The encoder prompt is always passed through `tokenize_prompt_async`; the
    decoder prompt is tokenized the same way unless it is `None`, in which
    case it stays `None`.
    """
    encoder_coro = self.tokenize_prompt_async(prompt["encoder_prompt"], params)

    decoder_raw = prompt["decoder_prompt"]
    if decoder_raw is None:
        # `asyncio.sleep(0)` resolves to `None`, so it acts as a no-op
        # placeholder that keeps the `gather` call uniform.
        decoder_coro = asyncio.sleep(0)
    else:
        decoder_coro = self.tokenize_prompt_async(decoder_raw, params)

    encoder_tok, decoder_tok = await asyncio.gather(encoder_coro, decoder_coro)

    return EncoderDecoderTokPrompt(
        encoder_prompt=encoder_tok,
        decoder_prompt=decoder_tok,
    )
|
||||
|
||||
@overload
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
def tokenize_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
@overload
|
||||
def tokenize_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt: ...
|
||||
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: DictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokPrompt:
|
||||
if "encoder_prompt" in prompt:
|
||||
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
prompt = self._tokenize_prompt(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
|
||||
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
def tokenize_prompts(
|
||||
self,
|
||||
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
|
||||
prompts: Sequence[DictPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
) -> list[TokPrompt]:
|
||||
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt: ...
|
||||
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
|
||||
prompt: DictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
) -> TokPrompt:
|
||||
if "encoder_prompt" in prompt:
|
||||
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt_token_ids = await tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
prompt = await self._tokenize_prompt_async(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
|
||||
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
async def tokenize_prompts_async(
|
||||
self,
|
||||
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
|
||||
prompts: Sequence[DictPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
) -> list[TokPrompt]:
|
||||
return await asyncio.gather(
|
||||
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
|
||||
)
|
||||
|
||||
@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .inputs import DictPrompt
|
||||
from .inputs.preprocess import parse_dec_only_prompt
|
||||
from .params import ChatParams
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
@@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.config
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = self.render_completion([1]) # Dummy token IDs
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
) -> tuple[list[ConversationMessage], DictPrompt]:
|
||||
model_config = self.config
|
||||
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
@@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = self.render_completion([1]) # Dummy token IDs
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
@@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.utils import get_prompt_text
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import (
|
||||
StatLoggerFactory,
|
||||
@@ -284,7 +285,11 @@ class AsyncLLM(EngineClient):
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -367,7 +372,7 @@ class AsyncLLM(EngineClient):
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
supported_tasks=await self.get_supported_tasks(),
|
||||
)
|
||||
prompt_text = get_prompt_text(prompt)
|
||||
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
@@ -484,7 +489,9 @@ class AsyncLLM(EngineClient):
|
||||
raise ValueError(
|
||||
"prompt_embeds not supported for streaming inputs"
|
||||
)
|
||||
prompt_text = get_prompt_text(input_chunk.prompt)
|
||||
prompt_text, _, _ = extract_prompt_components(
|
||||
self.model_config, input_chunk.prompt
|
||||
)
|
||||
await self._add_request(req, prompt_text, None, 0, queue)
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
cancelled = True
|
||||
@@ -528,7 +535,11 @@ class AsyncLLM(EngineClient):
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def generate(
|
||||
self,
|
||||
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
*,
|
||||
@@ -769,7 +780,7 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
|
||||
@@ -7,14 +7,13 @@ from typing import Any, Literal, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import (
|
||||
from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
)
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -243,8 +243,8 @@ class InputProcessor:
|
||||
return mm_processor.info.parse_mm_data(mm_data)
|
||||
|
||||
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
|
||||
if isinstance(prompt, str):
|
||||
prompt = TextPrompt(prompt=prompt)
|
||||
if not isinstance(prompt, dict):
|
||||
return
|
||||
|
||||
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
|
||||
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
|
||||
@@ -297,7 +297,7 @@ class InputProcessor:
|
||||
f"multi_modal_uuids[{modality!r}] is missing."
|
||||
)
|
||||
|
||||
def _validate_mm_uuids(self, prompt: PromptType) -> None:
|
||||
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
|
||||
"""
|
||||
Validate that user-provided multi_modal_uuids align with
|
||||
multi_modal_data in the incoming request prompt(s).
|
||||
@@ -305,10 +305,10 @@ class InputProcessor:
|
||||
auto-hashed downstream.
|
||||
"""
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
|
||||
|
||||
if (dec_prompt := prompt["decoder_prompt"]) is not None:
|
||||
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
|
||||
self._validate_singleton_mm_uuids(dec_prompt)
|
||||
else:
|
||||
self._validate_singleton_mm_uuids(prompt)
|
||||
@@ -449,21 +449,23 @@ class InputProcessor:
|
||||
def _extract_singleton_mm_data(
|
||||
self, prompt: SingletonPrompt
|
||||
) -> MultiModalDataDict | None:
|
||||
if isinstance(prompt, str):
|
||||
if not isinstance(prompt, dict):
|
||||
return None
|
||||
|
||||
return prompt.get("multi_modal_data") # type: ignore[return-value]
|
||||
return prompt.get("multi_modal_data")
|
||||
|
||||
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
return self._extract_singleton_mm_data(prompt["encoder_prompt"])
|
||||
def _extract_mm_data(
|
||||
self, prompt: PromptType | DictPrompt | TokPrompt
|
||||
) -> MultiModalDataDict | None:
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
|
||||
else:
|
||||
return self._extract_singleton_mm_data(prompt)
|
||||
|
||||
def _maybe_build_mm_uuids(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
) -> MultiModalUUIDDict | None:
|
||||
"""Build per-item multimodal hash overrides when enabled. In this case,
|
||||
multimodal data items are identified by their request id, modality and
|
||||
@@ -519,7 +521,7 @@ class InputProcessor:
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
|
||||
@@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.utils import get_prompt_text
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
@@ -216,7 +217,7 @@ class LLMEngine:
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest | PromptType,
|
||||
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -251,7 +252,7 @@ class LLMEngine:
|
||||
priority,
|
||||
supported_tasks=self.get_supported_tasks(),
|
||||
)
|
||||
prompt_text = get_prompt_text(prompt)
|
||||
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
|
||||
@@ -17,8 +17,6 @@ import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.parse import get_prompt_components
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
@@ -226,10 +224,6 @@ def get_device_indices(
|
||||
return value
|
||||
|
||||
|
||||
def get_prompt_text(prompt: PromptType) -> str | None:
|
||||
return get_prompt_components(prompt)[0]
|
||||
|
||||
|
||||
class CoreEngineActorManager:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
|
||||
Reference in New Issue
Block a user