[V0 Deprecation] Remove Prompt Adapters (#20588)

Author: Michael Goin
Date: 2025-07-23 19:36:48 -04:00
Committed by: GitHub
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 78c13e30e1
commit 82ec66f514
60 changed files with 126 additions and 1727 deletions

View File

@@ -14,7 +14,6 @@ API documentation for vLLM's configuration classes.
- [vllm.config.DeviceConfig][]
- [vllm.config.SpeculativeConfig][]
- [vllm.config.LoRAConfig][]
- [vllm.config.PromptAdapterConfig][]
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][]

View File

@@ -34,23 +34,22 @@ th:not(:first-child) {
}
</style>

| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | <abbr title="Prompt Adapter">prmpt adptr</abbr> | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | | | | | | | | | | | | |
| [SD](spec_decode.md) | ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | |
| CUDA graph | | | | | | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | ❌ | ❌ | ❌ | ❌ | | | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | | [](gh-issue:7366) | | | [](gh-issue:7366) | | ✅ | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | | ✅ | | ❌ | ✅ | ✅ | ✅ | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | | ✅ | |
| beam-search | ✅ | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ❔ | ✅ | ✅ |

| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | <abbr title="Pooling Models">pooling</abbr> | <abbr title="Encoder-Decoder Models">enc-dec</abbr> | <abbr title="Logprobs">logP</abbr> | <abbr title="Prompt Logprobs">prmpt logP</abbr> | <abbr title="Async Output Processing">async output</abbr> | multi-step | <abbr title="Multimodal Inputs">mm</abbr> | best-of | beam-search |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | |
| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | |
| [SD](spec_decode.md) | ✅ | ✅ | | | | | | | | | | | | |
| CUDA graph | ✅ | ✅ | | ✅ | ✅ | | | | | | | | | |
| <abbr title="Pooling Models">pooling</abbr> | | | | | | ✅ | | | | | | | | |
| <abbr title="Encoder-Decoder Models">enc-dec</abbr> | ❌ | [](gh-issue:7366) | ❌ | [](gh-issue:7366) | | | ✅ | | | | | | | |
| <abbr title="Logprobs">logP</abbr> | | | | | | | ✅ | ✅ | | | | | | |
| <abbr title="Prompt Logprobs">prmpt logP</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | |
| <abbr title="Async Output Processing">async output</abbr> | ✅ | ✅ | ✅ | | ✅ | | ❌ | ✅ | ✅ | ✅ | | | | |
| multi-step | ❌ | ✅ | | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | |
| <abbr title="Multimodal Inputs">mm</abbr> | | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | |
| best-of | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | ✅ | ✅ | |
| beam-search | ✅ | ✅ | ✅ | [](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](gh-issue:7968) | | ✅ | |
[](){ #feature-x-hardware }
@@ -61,7 +60,6 @@ th:not(:first-child) {
| [CP][chunked-prefill] | [](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [APC](automatic_prefix_caching.md) | [](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| <abbr title="Prompt Adapter">prmpt adptr</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | [](gh-issue:8475) | ✅ | ❌ |
| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
| <abbr title="Pooling Models">pooling</abbr> | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ |

View File

@@ -72,7 +72,6 @@ line-length = 80
"vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"]
# Python 3.8 typing - skip utils for ROCm # Python 3.8 typing - skip utils for ROCm
"vllm/utils/__init__.py" = ["UP006", "UP035"] "vllm/utils/__init__.py" = ["UP006", "UP035"]

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# imports for guided decoding tests
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Optional
@@ -26,10 +27,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8

GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@@ -56,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def zephyr_pa_files(): def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
@@ -81,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
] ]
@@ -98,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
    original_value = os.environ.get('VLLM_USE_V1')
    os.environ['VLLM_USE_V1'] = '0'
    try:
        with RemoteOpenAIServer(MODEL_NAME,
                                default_server_args) as remote_server:
            yield remote_server
    finally:
        # Restore original env value
        if original_value is None:
            os.environ.pop('VLLM_USE_V1', None)
        else:
            os.environ['VLLM_USE_V1'] = original_value
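The explicit save-and-restore is presumably needed because this fixture outlives a single test, where pytest's function-scoped `monkeypatch` fixture cannot be used. For a function-scoped fixture the same effect could be written more simply; a minimal sketch (illustrative only, not the code in this commit; `v0_server` is a hypothetical name and `RemoteOpenAIServer`/`MODEL_NAME` come from the surrounding test module):

```python
import pytest


@pytest.fixture
def v0_server(default_server_args, monkeypatch):
    # monkeypatch.setenv is undone automatically at teardown, so no explicit
    # try/finally is needed; this only works for function-scoped fixtures.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server
```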
@pytest_asyncio.fixture
@@ -110,14 +103,11 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    "model_name,num_virtual_tokens",
    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
                                 num_virtual_tokens: int):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
@@ -130,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5,
        prompt_tokens=6 + num_virtual_tokens,
        total_tokens=11 + num_virtual_tokens)
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
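The old expectation encoded the prompt-tuning arithmetic: an adapter's virtual tokens were prepended to the prompt and therefore counted in the reported usage. A small illustrative sketch of that accounting (not part of the commit):

```python
def expected_usage(prompt_tokens: int, completion_tokens: int,
                   num_virtual_tokens: int = 0) -> tuple[int, int]:
    # A prompt adapter's virtual tokens were prepended to the prompt and
    # therefore counted in the reported prompt/total token usage.
    billed_prompt = prompt_tokens + num_virtual_tokens
    return billed_prompt, billed_prompt + completion_tokens


assert expected_usage(6, 5) == (6, 11)       # base model or LoRA
assert expected_usage(6, 5, 8) == (14, 19)   # adapter with 8 virtual tokens
```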
@@ -175,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
@@ -194,9 +182,9 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora and 1 pa hereafter
    # just test 1 lora
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
@@ -217,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
@@ -238,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                             model_name: str):
@@ -314,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str):
@@ -348,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.
@@ -382,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
@@ -519,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
    [MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    # test both text and token IDs

View File

@@ -13,7 +13,6 @@ from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args  # noqa: F401
from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
from .test_completion import zephyr_lora_files  # noqa: F401
from .test_completion import zephyr_pa_files  # noqa: F401
from .test_completion import MODEL_NAME

View File

@@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
    serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None,
                                         prompt_adapters=None)
                                         lora_modules=None)
    await serving_models.init_static_loras()
    return serving_models

View File

@@ -1,48 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output

View File

@@ -1,56 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output

View File

@@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output

View File

@@ -31,6 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/worker
run_mypy vllm/v1

View File

@@ -3143,59 +3143,6 @@ class LoRAConfig:
            self.lora_dtype = getattr(torch, self.lora_dtype)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
    """Configuration for PromptAdapters."""

    max_prompt_adapters: int = 1
    """Max number of PromptAdapters in a batch."""
    max_prompt_adapter_token: int = 0
    """Max number of PromptAdapters tokens."""
    max_cpu_prompt_adapters: Optional[int] = None
    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
    `max_prompt_adapters`."""
    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
    """Data type for PromptAdapter. If auto, will default to base model dtype.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self):
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_cpu_prompt_adapters is None:
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        if self.prompt_adapter_dtype == "auto":
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(self.prompt_adapter_dtype, str):
            self.prompt_adapter_dtype = getattr(torch,
                                                self.prompt_adapter_dtype)
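For reference, the removed `compute_hash` follows the same factors-based pattern used by vLLM's other config classes: collect the fields that influence the compiled computation graph, stringify them, and hash. A minimal standalone sketch of that pattern (an assumed simplification, not vLLM's exact code):

```python
import hashlib
from typing import Any


def factors_hash(factors: list[Any]) -> str:
    # Stable fingerprint of the config fields that shape the computation
    # graph; a config with no such fields (like PromptAdapterConfig) simply
    # hashes an empty list.
    return hashlib.md5(str(factors).encode(),
                       usedforsecurity=False).hexdigest()


assert factors_hash([]) == factors_hash([])  # same factors, same hash
```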
@config
@dataclass
class MultiModalConfig:
@@ -4402,8 +4349,6 @@ class VllmConfig:
"""Decoding configuration.""" """Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration.""" """Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
"""Quantization configuration.""" """Quantization configuration."""
compilation_config: CompilationConfig = field( compilation_config: CompilationConfig = field(
@@ -4500,10 +4445,6 @@ class VllmConfig:
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config: if self.quant_config:
pass # should be captured by model_config.quantization pass # should be captured by model_config.quantization
if self.compilation_config: if self.compilation_config:
@@ -4611,9 +4552,6 @@ class VllmConfig:
if self.lora_config is not None: if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
if self.quant_config is None and self.model_config is not None: if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config( self.quant_config = VllmConfig._get_quantization_config(

View File

@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata, SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage, SequenceGroupMetadataDelta, SequenceStage,
@@ -165,8 +164,6 @@ class SchedulerOutputs:
if self.num_loras > 0: if self.num_loras > 0:
self._sort_by_lora_ids() self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
@@ -194,14 +191,6 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None if g.seq_group.lora_request is not None
} }
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass @dataclass
class SchedulerRunningOutputs: class SchedulerRunningOutputs:
@@ -1648,7 +1637,6 @@ class Scheduler:
multi_modal_placeholders=( multi_modal_placeholders=(
seq_group.multi_modal_placeholders seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None), if scheduler_outputs.num_prefill_groups > 0 else None),
prompt_adapter_request=seq_group.prompt_adapter_request,
) )
else: else:
# When SPMD mode is enabled, we only send delta data except for # When SPMD mode is enabled, we only send delta data except for

View File

@@ -30,9 +30,9 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
                         get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -358,11 +358,6 @@ class EngineArgs:
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
    # PromptAdapter fields
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
    max_prompt_adapter_token: int = \
        PromptAdapterConfig.max_prompt_adapter_token

    num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
    multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
@@ -437,6 +432,8 @@ class EngineArgs:
        ParallelConfig.enable_multimodal_encoder_data_parallel
    async_scheduling: bool = SchedulerConfig.async_scheduling

    # DEPRECATED
    enable_prompt_adapter: bool = False
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
@@ -729,23 +726,6 @@ class EngineArgs:
lora_group.add_argument("--default-mm-loras", lora_group.add_argument("--default-mm-loras",
**lora_kwargs["default_mm_loras"]) **lora_kwargs["default_mm_loras"])
        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
        prompt_adapter_group = parser.add_argument_group(
            title="PromptAdapterConfig",
            description=PromptAdapterConfig.__doc__,
        )
        prompt_adapter_group.add_argument(
            "--enable-prompt-adapter",
            action=argparse.BooleanOptionalAction,
            help="If True, enable handling of PromptAdapters.")
        prompt_adapter_group.add_argument(
            "--max-prompt-adapters",
            **prompt_adapter_kwargs["max_prompt_adapters"])
        prompt_adapter_group.add_argument(
            "--max-prompt-adapter-token",
            **prompt_adapter_kwargs["max_prompt_adapter_token"])

        # Speculative arguments
        speculative_group = parser.add_argument_group(
            title="SpeculativeConfig",
@@ -850,6 +830,12 @@ class EngineArgs:
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='Disable logging statistics.')
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            deprecated=True,
                            help='[DEPRECATED] Prompt adapter has been '
                            'removed. Setting this flag to True or False'
                            ' has no effect on vLLM behavior.')

        return parser
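After this change the flag is still parsed for backwards compatibility but does nothing; a minimal sketch of what that means for existing callers (illustrative only):

```python
from vllm import EngineArgs

# The field stays on EngineArgs so old scripts and launch commands keep
# working, but it no longer creates a PromptAdapterConfig or alters the
# engine's behavior.
args = EngineArgs(model="HuggingFaceH4/zephyr-7b-beta",
                  enable_prompt_adapter=True)  # accepted, ignored
```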
@@ -1234,11 +1220,6 @@ class EngineArgs:
        load_config = self.create_load_config()

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            backend=self.guided_decoding_backend,
            disable_fallback=self.guided_decoding_disable_fallback,
@@ -1266,7 +1247,6 @@ class EngineArgs:
            load_config=load_config,
            decoding_config=decoding_config,
            observability_config=observability_config,
            prompt_adapter_config=prompt_adapter_config,
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
@@ -1342,12 +1322,6 @@ class EngineArgs:
                               recommend_to_remove=False)
            return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No text embedding inputs so far.
        if self.enable_prompt_embeds:
            _raise_or_fallback(feature_name="--enable-prompt-embeds",
@@ -1469,7 +1443,6 @@ class EngineArgs:
            if (is_gpu and not use_sliding_window and not use_spec_decode
                    and not self.enable_lora
                    and not self.enable_prompt_adapter
                    and model_config.runner_type != "pooling"):
                self.enable_chunked_prefill = True
                logger.warning(

View File

@@ -29,7 +29,6 @@ from vllm.model_executor.guided_decoding import (
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -435,7 +434,6 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
@@ -468,7 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
@@ -491,7 +488,6 @@ class _AsyncLLMEngine(LLMEngine):
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
) )
@@ -861,7 +857,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
@@ -889,7 +884,6 @@ class AsyncLLMEngine(EngineClient):
arrival_time=arrival_time or time.time(), arrival_time=arrival_time or time.time(),
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
@@ -904,7 +898,6 @@ class AsyncLLMEngine(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -922,8 +915,6 @@ class AsyncLLMEngine(EngineClient):
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request. priority: The priority of the request.
Only applicable with priority scheduling. Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must data_parallel_rank: The (global) data parallel rank that must
@@ -983,7 +974,6 @@ class AsyncLLMEngine(EngineClient):
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
): ):

View File

@@ -44,7 +44,6 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup, PoolingSequenceGroupOutput, Sequence, SequenceGroup,
@@ -223,7 +222,6 @@ class LLMEngine:
self.load_config = vllm_config.load_config self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
) )
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
) )
@@ -294,8 +292,6 @@ class LLMEngine:
# Feature flags # Feature flags
"enable_lora": "enable_lora":
bool(self.lora_config), bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching": "enable_prefix_caching":
self.cache_config.enable_prefix_caching, self.cache_config.enable_prefix_caching,
"enforce_eager": "enforce_eager":
@@ -542,9 +538,6 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _add_processed_request( def _add_processed_request(
self, self,
@@ -553,7 +546,6 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0, priority: int = 0,
) -> Optional[SequenceGroup]: ) -> Optional[SequenceGroup]:
@@ -569,7 +561,6 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
) )
return None return None
@@ -583,11 +574,10 @@ class LLMEngine:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request)
encoder_seq = (None if encoder_inputs is None else Sequence( encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request, seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
prompt_adapter_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
@@ -598,7 +588,6 @@ class LLMEngine:
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
@@ -608,7 +597,6 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq, encoder_seq=encoder_seq,
priority=priority) priority=priority)
else: else:
@@ -637,7 +625,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
@@ -658,7 +645,6 @@ class LLMEngine:
the current monotonic time. the current monotonic time.
lora_request: The LoRA request to add. lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request. priority: The priority of the request.
Only applicable with priority scheduling. Only applicable with priority scheduling.
@@ -719,7 +705,6 @@ class LLMEngine:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
self._add_processed_request( self._add_processed_request(
@@ -728,7 +713,6 @@ class LLMEngine:
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
) )
@@ -741,7 +725,6 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0, priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
@@ -769,14 +752,12 @@ class LLMEngine:
if self.vllm_config.speculative_config is not None: if self.vllm_config.speculative_config is not None:
draft_size = \ draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1 self.vllm_config.speculative_config.num_speculative_tokens + 1
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            sampling_params=sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority,
            draft_size=draft_size)
        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  sampling_params=sampling_params,
                                  lora_request=lora_request,
                                  trace_headers=trace_headers,
                                  encoder_seq=encoder_seq,
                                  priority=priority,
                                  draft_size=draft_size)
@@ -790,7 +771,6 @@ class LLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
priority: int = 0, priority: int = 0,
) -> SequenceGroup: ) -> SequenceGroup:
@@ -798,13 +778,11 @@ class LLMEngine:
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone() pooling_params = pooling_params.clone()
# Create the sequence group. # Create the sequence group.
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],
            arrival_time=arrival_time,
            lora_request=lora_request,
            pooling_params=pooling_params,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
        seq_group = SequenceGroup(request_id=request_id,
                                  seqs=[seq],
                                  arrival_time=arrival_time,
                                  lora_request=lora_request,
                                  pooling_params=pooling_params,
                                  encoder_seq=encoder_seq,
                                  priority=priority)
return seq_group return seq_group
@@ -1834,16 +1812,6 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None: def start_profile(self) -> None:
self.model_executor.start_profile() self.model_executor.start_profile()

View File

@@ -10,7 +10,6 @@ from vllm import PoolingParams
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import Device from vllm.utils import Device
@@ -33,7 +32,6 @@ class RPCProcessRequest:
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0 priority: int = 0
def __init__( def __init__(
@@ -43,7 +41,6 @@ class RPCProcessRequest:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -53,7 +50,6 @@ class RPCProcessRequest:
self.request_id = request_id self.request_id = request_id
self.lora_request = lora_request self.lora_request = lora_request
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority self.priority = priority

View File

@@ -45,7 +45,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device from vllm.utils import Device
@@ -448,7 +447,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request. """Generate outputs for a request.
@@ -465,8 +463,6 @@ class MQLLMEngineClient(EngineClient):
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling). priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the Any priority other than 0 will lead to an error if the
scheduling policy is not "priority". scheduling policy is not "priority".
@@ -474,8 +470,7 @@ class MQLLMEngineClient(EngineClient):
        return cast(
            AsyncGenerator[RequestOutput, None],
            self._process_request(prompt, sampling_params, request_id,
                                  lora_request, trace_headers,
                                  prompt_adapter_request, priority))
                                  lora_request, trace_headers, priority))
def encode( def encode(
self, self,
@@ -521,7 +516,6 @@ class MQLLMEngineClient(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
@@ -575,7 +569,6 @@ class MQLLMEngineClient(EngineClient):
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
)) ))

View File

@@ -304,13 +304,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err) self._send_outputs(rpc_err)
try: try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)
            self.engine.add_request(request_id=request_id,
                                    prompt=request.prompt,
                                    params=request.params,
                                    lora_request=request.lora_request,
                                    trace_headers=request.trace_headers,
                                    priority=request.priority)
if self.log_requests: if self.log_requests:

View File

@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid from vllm.utils import Device, collect_from_async_generator, random_uuid
@@ -55,7 +54,6 @@ class EngineClient(ABC):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.""" """Generate outputs for a request."""

View File

@@ -45,7 +45,6 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput, PoolingRequestOutput, RequestOutput,
ScoringRequestOutput) ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams) RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
@@ -314,7 +313,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -330,7 +328,6 @@ class LLM:
prompt_token_ids: Optional[list[int]] = None, prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -346,7 +343,6 @@ class LLM:
prompt_token_ids: Optional[list[list[int]]] = None, prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -363,7 +359,6 @@ class LLM:
prompt_token_ids: list[int], prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -380,7 +375,6 @@ class LLM:
prompt_token_ids: list[list[int]], prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -395,7 +389,6 @@ class LLM:
prompt_token_ids: Union[list[int], list[list[int]]], prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
) -> list[RequestOutput]: ) -> list[RequestOutput]:
@@ -415,7 +408,6 @@ class LLM:
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions, guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None, GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None, priority: Optional[list[int]] = None,
@@ -440,8 +432,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any. priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled. Only applicable when priority scheduling policy is enabled.
@@ -507,7 +497,6 @@ class LLM:
params=sampling_params, params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request, guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
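For reference, `LLM.generate` keeps `lora_request` but no longer exposes `prompt_adapter_request`. A hedged usage sketch (the model name is a placeholder):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.8, max_tokens=32),
    # prompt_adapter_request=...  # removed by this commit
)
for out in outputs:
    print(out.outputs[0].text)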
@@ -963,7 +952,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -980,7 +968,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -997,7 +984,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -1015,7 +1001,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -1033,7 +1018,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -1049,7 +1033,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -1070,7 +1053,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@@ -1092,8 +1074,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use. pooling_task: Override the pooling task to use.
Returns: Returns:
@@ -1150,7 +1130,6 @@ class LLM:
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1167,7 +1146,6 @@ class LLM:
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[EmbeddingRequestOutput]: ) -> list[EmbeddingRequestOutput]:
""" """
Generate an embedding vector for each prompt. Generate an embedding vector for each prompt.
@@ -1187,8 +1165,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
@@ -1205,7 +1181,6 @@ class LLM:
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
pooling_params=pooling_params, pooling_params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed", pooling_task="embed",
) )
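The pooling entrypoints (`encode`, `embed`, `classify`, `score`) follow the same pattern. A hedged sketch of `LLM.embed` after the change; the model name and `task` setting are illustrative:

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.embed(["The capital of France is Paris."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality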
@@ -1218,7 +1193,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ClassificationRequestOutput]: ) -> list[ClassificationRequestOutput]:
""" """
Generate class logits for each prompt. Generate class logits for each prompt.
@@ -1236,8 +1210,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `ClassificationRequestOutput` objects containing the A list of `ClassificationRequestOutput` objects containing the
@@ -1253,7 +1225,6 @@ class LLM:
prompts, prompts,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="classify", pooling_task="classify",
) )
@@ -1267,7 +1238,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode( encoded_output: list[PoolingRequestOutput] = self.encode(
@@ -1275,7 +1245,6 @@ class LLM:
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed", pooling_task="embed",
) )
@@ -1303,7 +1272,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
@@ -1361,7 +1329,6 @@ class LLM:
params=pooling_params, params=pooling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1381,7 +1348,6 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or """Generate similarity scores for all pairs `<text,text_pair>` or
`<multi-modal data, multi-modal data pair>`. `<multi-modal data, multi-modal data pair>`.
@@ -1412,8 +1378,6 @@ class LLM:
it is used to create the progress bar. it is used to create the progress bar.
If `False`, no progress bar is created. If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `ScoringRequestOutput` objects containing the A list of `ScoringRequestOutput` objects containing the
@@ -1504,8 +1468,7 @@ class LLM:
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, truncate_prompt_tokens,
use_tqdm, use_tqdm,
lora_request, lora_request)
prompt_adapter_request)
else: else:
return self._embedding_score( return self._embedding_score(
tokenizer, tokenizer,
@@ -1513,8 +1476,7 @@ class LLM:
data_2, # type: ignore[arg-type] data_2, # type: ignore[arg-type]
truncate_prompt_tokens, truncate_prompt_tokens,
use_tqdm, use_tqdm,
lora_request, lora_request)
prompt_adapter_request)
def start_profile(self) -> None: def start_profile(self) -> None:
self.llm_engine.start_profile() self.llm_engine.start_profile()
@@ -1625,7 +1587,6 @@ class LLM:
*, *,
use_tqdm: Union[bool, Callable[..., tqdm]] = True, use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None, priority: Optional[list[int]] = None,
@@ -1671,7 +1632,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0, priority=priority[i] if priority else 0,
) )
@@ -1681,7 +1641,6 @@ class LLM:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
@@ -1691,7 +1650,6 @@ class LLM:
params, params,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
) )

View File

@@ -8,7 +8,6 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -30,7 +29,6 @@ class RequestLogger:
params: Optional[Union[SamplingParams, PoolingParams, params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]], BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
max_log_len = self.max_log_len max_log_len = self.max_log_len
if max_log_len is not None: if max_log_len is not None:
@@ -44,7 +42,6 @@ class RequestLogger:
"Received request %s: prompt: %r, " "Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, " "params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, " "prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id, "lora_request: %s.", request_id, prompt, params, prompt_token_ids,
prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None, prompt_embeds.shape if prompt_embeds is not None else None,
lora_request, prompt_adapter_request) lora_request)
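A hedged sketch of the new `RequestLogger.log_inputs` call shape (all values below are placeholders):

from vllm.entrypoints.logger import RequestLogger
from vllm.sampling_params import SamplingParams

request_logger = RequestLogger(max_log_len=256)
request_logger.log_inputs(
    "cmpl-123",            # request_id
    "Hello, world",        # prompt
    [15496, 11, 995],      # prompt_token_ids
    None,                  # prompt_embeds
    params=SamplingParams(),
    lora_request=None,     # prompt_adapter_request no longer exists
)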

View File

@@ -1620,7 +1620,6 @@ async def init_app_state(
model_config=model_config, model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=args.prompt_adapters,
) )
await state.openai_serving_models.init_static_loras() await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = OpenAIServingResponses( state.openai_serving_responses = OpenAIServingResponses(

View File

@@ -20,8 +20,7 @@ from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template) validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath, from vllm.entrypoints.openai.serving_models import LoRAModulePath
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@@ -65,27 +64,6 @@ class LoRAParserAction(argparse.Action):
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: list[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
setattr(namespace, self.dest, adapter_list)
@config @config
@dataclass @dataclass
class FrontendArgs: class FrontendArgs:
@@ -115,9 +93,6 @@ class FrontendArgs:
or JSON list format. Example (old format): `'name=path'` Example (new or JSON list format. Example (old format): `'name=path'` Example (new
format): `{\"name\": \"name\", \"path\": \"lora_path\", format): `{\"name\": \"name\", \"path\": \"lora_path\",
\"base_model_name\": \"id\"}`""" \"base_model_name\": \"id\"}`"""
prompt_adapters: Optional[list[PromptAdapterPath]] = None
"""Prompt adapter configurations in the format name=path. Multiple adapters
can be specified."""
chat_template: Optional[str] = None chat_template: Optional[str] = None
"""The file path to the chat template, or the template in single-line form """The file path to the chat template, or the template in single-line form
for the specified model.""" for the specified model."""
@@ -207,12 +182,6 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Prompt adapters need custom parser action and
# optional_type(str)
frontend_kwargs["prompt_adapters"]["type"] = optional_type(str)
frontend_kwargs["prompt_adapters"][
"action"] = PromptAdapterParserAction
# Special case: Middleware needs append action # Special case: Middleware needs append action
frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str frontend_kwargs["middleware"]["type"] = str
@@ -288,9 +257,6 @@ def validate_parsed_serve_args(args: argparse.Namespace):
if args.enable_auto_tool_choice and not args.tool_call_parser: if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires " raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser") "--tool-call-parser")
if args.enable_prompt_embeds and args.enable_prompt_adapter:
raise ValueError(
"Cannot use prompt embeds and prompt adapter at the same time.")
def log_non_default_args(args: argparse.Namespace): def log_non_default_args(args: argparse.Namespace):

View File

@@ -337,7 +337,6 @@ async def main(args):
model_config=model_config, model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=None, lora_modules=None,
prompt_adapters=None,
) )
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, engine,

View File

@@ -147,11 +147,8 @@ class OpenAIServingChat(OpenAIServing):
raise self.engine_client.dead_error raise self.engine_client.dead_error
try: try:
( lora_request = self._maybe_get_adapters(
lora_request, request, supports_default_mm_loras=True)
prompt_adapter_request,
) = self._maybe_get_adapters(request,
supports_default_mm_loras=True)
model_name = self._get_model_name(request.model, lora_request) model_name = self._get_model_name(request.model, lora_request)
@@ -239,8 +236,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))
@@ -259,7 +255,6 @@ class OpenAIServingChat(OpenAIServing):
request_id, request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority, priority=request.priority,
) )
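From the client's point of view the Chat Completions API is unchanged; the only behavioral difference is that a prompt-adapter name can no longer be used as the `model`, since those adapters are no longer served. A hedged client-side example (server URL and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="my-base-model",  # or a served LoRA adapter name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)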

View File

@@ -49,19 +49,11 @@ class ClassificationMixin(OpenAIServing):
return None return None
try: try:
( ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
ctx.tokenizer = await self.engine_client.get_tokenizer( ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request) ctx.lora_request)
if ctx.prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported for classification models"
)
( (
ctx.request_prompts, ctx.request_prompts,
ctx.engine_prompts, ctx.engine_prompts,

View File

@@ -121,10 +121,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
@@ -197,7 +194,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
@@ -221,7 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
priority=request.priority, priority=request.priority,
) )

View File

@@ -53,18 +53,11 @@ class EmbeddingMixin(OpenAIServing):
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
ctx = cast(EmbeddingServeContext, ctx) ctx = cast(EmbeddingServeContext, ctx)
try: try:
( ctx.lora_request = self._maybe_get_adapters(ctx.request)
ctx.lora_request,
ctx.prompt_adapter_request,
) = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
) )
if ctx.prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
( (
_, _,

View File

@@ -68,7 +68,6 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict) MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -161,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
request_id: str request_id: str
created_time: int = Field(default_factory=lambda: int(time.time())) created_time: int = Field(default_factory=lambda: int(time.time()))
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
# Shared across most requests # Shared across most requests
tokenizer: Optional[AnyTokenizer] = None tokenizer: Optional[AnyTokenizer] = None
@@ -343,12 +341,10 @@ class OpenAIServing:
return self.create_error_response( return self.create_error_response(
"Request prompts not available") "Request prompts not available")
self._log_inputs( self._log_inputs(request_id_item,
request_id_item,
ctx.request_prompts[i], ctx.request_prompts[i],
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request)
prompt_adapter_request=ctx.prompt_adapter_request)
# Mypy has an existing bug related to inferring the variance of # Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`: # TypedDicts with `builtins.enumerate`:
@@ -450,11 +446,6 @@ class OpenAIServing:
if isinstance(load_result, ErrorResponse) and \ if isinstance(load_result, ErrorResponse) and \
load_result.code == HTTPStatus.BAD_REQUEST.value: load_result.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result error_response = load_result
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return error_response or self.create_error_response( return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
@@ -489,25 +480,21 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
supports_default_mm_loras: bool = False, supports_default_mm_loras: bool = False,
) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ ) -> Optional[LoRARequest]:
None, PromptAdapterRequest]]:
if request.model in self.models.lora_requests: if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model], None return self.models.lora_requests[request.model]
# Currently only support default modality specific loras # Currently only support default modality specific loras
# if we have exactly one lora matched on the request. # if we have exactly one lora matched on the request.
if supports_default_mm_loras: if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request) default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None: if default_mm_lora is not None:
return default_mm_lora, None return default_mm_lora
if self._is_model_supported(request.model): if self._is_model_supported(request.model):
return None, None return None
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable # if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
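`_maybe_get_adapters` now returns `Optional[LoRARequest]` instead of a `(lora_request, prompt_adapter_request)` tuple. A standalone sketch of the narrowed lookup logic, not the actual method (names are illustrative):

from typing import Optional

from vllm.lora.request import LoRARequest


def resolve_adapter(model_name: str,
                    lora_requests: dict[str, LoRARequest],
                    served_model_names: set[str]) -> Optional[LoRARequest]:
    # Route to a served LoRA adapter if the requested model matches one.
    if model_name in lora_requests:
        return lora_requests[model_name]
    # Otherwise the request must target a base (served) model.
    if model_name in served_model_names:
        return None
    raise ValueError(f"The model `{model_name}` does not exist.")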
@@ -987,7 +974,6 @@ class OpenAIServing:
params: Optional[Union[SamplingParams, PoolingParams, params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]], BeamSearchParams]],
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
if self.request_logger is None: if self.request_logger is None:
return return
@@ -1009,7 +995,6 @@ class OpenAIServing:
prompt_embeds, prompt_embeds,
params=params, params=params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
async def _get_trace_headers( async def _get_trace_headers(

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pathlib
from asyncio import Lock from asyncio import Lock
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
@@ -19,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter from vllm.utils import AtomicCounter
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -31,12 +28,6 @@ class BaseModelPath:
model_path: str model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass @dataclass
class LoRAModulePath: class LoRAModulePath:
name: str name: str
@@ -60,7 +51,6 @@ class OpenAIServingModels:
base_model_paths: list[BaseModelPath], base_model_paths: list[BaseModelPath],
*, *,
lora_modules: Optional[list[LoRAModulePath]] = None, lora_modules: Optional[list[LoRAModulePath]] = None,
prompt_adapters: Optional[list[PromptAdapterPath]] = None,
): ):
super().__init__() super().__init__()
@@ -81,20 +71,6 @@ class OpenAIServingModels:
LoRAResolverRegistry.get_resolver(lora_resolver_name)) LoRAResolverRegistry.get_resolver(lora_resolver_name))
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def init_static_loras(self): async def init_static_loras(self):
"""Loads all static LoRA modules. """Loads all static LoRA modules.
Raises if any fail to load""" Raises if any fail to load"""
@@ -141,14 +117,7 @@ class OpenAIServingModels:
permission=[ModelPermission()]) permission=[ModelPermission()])
for lora in self.lora_requests.values() for lora in self.lora_requests.values()
] ]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards) model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards) return ModelList(data=model_cards)
async def load_lora_adapter( async def load_lora_adapter(
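A hedged sketch of constructing the models registry after this change; `engine_client` and `model_config` are assumed to be obtained elsewhere, and the paths are placeholders:

from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServingModels)

serving_models = OpenAIServingModels(
    engine_client,
    model_config,
    base_model_paths=[BaseModelPath(name="my-model",
                                    model_path="/models/my-model")],
    lora_modules=[LoRAModulePath(name="my-lora", path="/adapters/my-lora")],
    # prompt_adapters=[...]  # removed by this commit
)

Prompt adapter model cards are likewise dropped from the `/v1/models` listing; only base models and LoRA adapters remain.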

View File

@@ -94,17 +94,10 @@ class OpenAIServingPooling(OpenAIServing):
try: try:
truncate_prompt_tokens = _validate_truncation_size( truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens) self.max_model_len, truncate_prompt_tokens)
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for pooling models")
if isinstance(request, PoolingChatRequest): if isinstance(request, PoolingChatRequest):
( (
_, _,
@@ -153,8 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompts[i], request_prompts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))

View File

@@ -133,10 +133,7 @@ class OpenAIServingResponses(OpenAIServing):
messages = self._construct_input_messages(request, prev_response) messages = self._construct_input_messages(request, prev_response)
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request) model_name = self._get_model_name(request.model, lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
@@ -169,8 +166,7 @@ class OpenAIServingResponses(OpenAIServing):
self._log_inputs(request.request_id, self._log_inputs(request.request_id,
request_prompts[i], request_prompts[i],
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers)) self._get_trace_headers(raw_request.headers))
@@ -181,7 +177,6 @@ class OpenAIServingResponses(OpenAIServing):
request.request_id, request.request_id,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority, priority=request.priority,
) )
generators.append(generator) generators.append(generator)

View File

@@ -27,7 +27,6 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators from vllm.utils import make_async, merge_async_iterators
@@ -58,8 +57,6 @@ class ServingScores(OpenAIServing):
request_id: str, request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None, lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
input_texts = texts_1 + texts_2 input_texts = texts_1 + texts_2
@@ -100,8 +97,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
input_texts[i], input_texts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
generators.append( generators.append(
self.engine_client.encode( self.engine_client.encode(
@@ -176,8 +172,6 @@ class ServingScores(OpenAIServing):
request_id: str, request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None, lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
request_prompts: list[str] = [] request_prompts: list[str] = []
@@ -261,8 +255,7 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item, self._log_inputs(request_id_item,
request_prompts[i], request_prompts[i],
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
@@ -295,14 +288,7 @@ class ServingScores(OpenAIServing):
raw_request: Optional[Request] = None, raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None, truncate_prompt_tokens: Optional[int] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]: ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
@@ -340,7 +326,6 @@ class ServingScores(OpenAIServing):
request_id=request_id, request_id=request_id,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers) trace_headers=trace_headers)
else: else:
@@ -352,7 +337,6 @@ class ServingScores(OpenAIServing):
request_id=request_id, request_id=request_id,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers) trace_headers=trace_headers)
async def create_score( async def create_score(

View File

@@ -60,10 +60,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}" request_id = f"tokn-{self._base_request_id(raw_request)}"
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
@@ -104,11 +101,8 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs(request_id, self._log_inputs(request_id,
request_prompts[i], request_prompts[i],
params=None, params=None,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
if isinstance(engine_prompt, if isinstance(engine_prompt,
dict) and "prompt_token_ids" in engine_prompt: dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"]) input_ids.extend(engine_prompt["prompt_token_ids"])
@@ -133,21 +127,14 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{self._base_request_id(raw_request)}" request_id = f"tokn-{self._base_request_id(raw_request)}"
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
params=None, params=None,
lora_request=lora_request, lora_request=lora_request)
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
prompt_input = await self._tokenize_prompt_input_async( prompt_input = await self._tokenize_prompt_input_async(
request, request,

View File

@@ -150,19 +150,12 @@ class OpenAISpeechToText(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
( lora_request = self._maybe_get_adapters(request)
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if lora_request: if lora_request:
return self.create_error_response( return self.create_error_response(
"Currently do not support LoRA for " "Currently do not support LoRA for "
f"{self.task_type.title()}.") f"{self.task_type.title()}.")
if prompt_adapter_request:
return self.create_error_response(
f"Currently do not support PromptAdapter for "
f"{self.task_type.title()}.")
prompts, duration_s = await self._preprocess_speech_to_text( prompts, duration_s = await self._preprocess_speech_to_text(
request=request, request=request,
@@ -188,8 +181,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|> # It will not display special tokens like <|startoftranscript|>
request.prompt, request.prompt,
params=sampling_params, params=sampling_params,
lora_request=None, lora_request=None)
prompt_adapter_request=None)
list_result_generator = [ list_result_generator = [
self.engine_client.generate( self.engine_client.generate(

View File

@@ -17,7 +17,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask from vllm.pooling_params import PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
@@ -50,7 +49,6 @@ class ExecutorBase(ABC):
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
@@ -171,35 +169,6 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs." assert s == sets[0], "All workers should have the same LORAs."
return sets[0] return sets[0]
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("add_prompt_adapter",
args=(prompt_adapter_request, )))
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("remove_prompt_adapter",
args=(prompt_adapter_id, )))
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return all(
self.collective_rpc("pin_prompt_adapter",
args=(prompt_adapter_id, )))
def list_prompt_adapters(self) -> Set[int]:
sets = self.collective_rpc("list_prompt_adapters")
for s in sets:
assert (s == sets[0]
), "All workers should have the same prompt adapters."
return sets[0]
def start_profile(self) -> None: def start_profile(self) -> None:
self.collective_rpc("start_profile") self.collective_rpc("start_profile")

View File

@@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -168,18 +167,6 @@ class InputPreprocessor:
return decoder_input_ids return decoder_input_ids
def _apply_prompt_adapter(
self,
prompt_token_ids: list[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> list[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return prompt_token_ids
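For reference, the deleted `_apply_prompt_adapter` step only prepended `num_virtual_tokens` placeholder token ids so the embedding layer could later substitute the learned virtual-token embeddings. A standalone illustration of that old behavior (not vLLM code):

def apply_prompt_adapter(prompt_token_ids: list[int],
                         num_virtual_tokens: int) -> list[int]:
    # Prepend one placeholder id (0) per virtual token.
    return [0] * num_virtual_tokens + prompt_token_ids


assert apply_prompt_adapter([101, 2023], 4) == [0, 0, 0, 0, 101, 2023]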
def _get_tokenization_kw( def _get_tokenization_kw(
self, self,
overrides: Optional[dict[str, Any]] = None, overrides: Optional[dict[str, Any]] = None,
@@ -786,15 +773,10 @@ class InputPreprocessor:
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
self, self,
prompt_inputs: DecoderOnlyInputs, prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
if "prompt_token_ids" in prompt_inputs: if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
return prompt_inputs return prompt_inputs
@@ -803,7 +785,6 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
@@ -815,7 +796,6 @@ class InputPreprocessor:
* prompt: input prompt * prompt: input prompt
* lora_request * lora_request
* prompt_adapter_request
* return_mm_hashes * return_mm_hashes
Returns: Returns:
@@ -830,17 +810,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
return self._build_decoder_only_llm_inputs( return self._build_decoder_only_llm_inputs(prompt_comps)
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
async def _process_decoder_only_prompt_async( async def _process_decoder_only_prompt_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
@@ -854,17 +830,13 @@ class InputPreprocessor:
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
return self._build_decoder_only_llm_inputs( return self._build_decoder_only_llm_inputs(prompt_comps)
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
@@ -886,7 +858,6 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
@@ -895,7 +866,6 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
@@ -919,6 +889,5 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
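The preprocessing entrypoint therefore simplifies to the signature below; a hedged call sketch, assuming `input_preprocessor` is the engine's `InputPreprocessor` instance:

processor_inputs = input_preprocessor.preprocess(
    "Hello, world",
    tokenization_kwargs=None,
    lora_request=None,
    # prompt_adapter_request is no longer part of the signature
    return_mm_hashes=False,
)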

View File

@@ -1,83 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
pass
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.emb_layer = self.base_layer
if 'LoRA' in base_layer.__class__.__name__:
self.emb_layer = self.base_layer.base_layer
def create_prompt_adapter_weights(
self, prompt_adapter_config: PromptAdapterConfig):
self.embeddings_tensors = torch.zeros(
(
prompt_adapter_config.max_prompt_adapters,
prompt_adapter_config.max_prompt_adapter_token,
self.emb_layer.embedding_dim,
),
dtype=self.emb_layer.weight.dtype,
device=self.emb_layer.weight.device,
)
self.adapter_lengths = torch.zeros(
prompt_adapter_config.max_prompt_adapters,
dtype=torch.long,
device=self.emb_layer.weight.device)
self.indices_gpu: torch.Tensor
self.embedding_indices_gpu: torch.Tensor
def reset_prompt_adapter(self, index: int):
self.embeddings_tensors[index] = 0
def set_prompt_adapter(
self,
index: int,
adapter_model: Optional[torch.Tensor],
):
self.reset_prompt_adapter(index)
if adapter_model is not None:
length = adapter_model.shape[0]
self.embeddings_tensors[index, :length] = adapter_model
self.adapter_lengths[index] = length
def set_mapping(
self,
prompt_indices: torch.Tensor,
prompt_embedding_indices: torch.Tensor,
):
self.indices_gpu = prompt_indices.to(
device=self.emb_layer.weight.device)
self.embedding_indices_gpu = prompt_embedding_indices.to(
device=self.emb_layer.weight.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = self.base_layer(x)
if self.embedding_indices_gpu.ndim > 1:
valid_mask = self.indices_gpu != -1
gathered_embeddings = self.embeddings_tensors[
self.embedding_indices_gpu[:, 0],
self.embedding_indices_gpu[:, 1]]
# Update hidden states
hidden_states[valid_mask] = gathered_embeddings
return hidden_states

View File

@@ -1,358 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Type
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.utils import load_peft_weights
logger = logging.getLogger(__name__)
_GLOBAL_PROMPT_ADAPTER_ID = 0
def get_prompt_adapter_id():
global _GLOBAL_PROMPT_ADAPTER_ID
_GLOBAL_PROMPT_ADAPTER_ID += 1
return _GLOBAL_PROMPT_ADAPTER_ID
def convert_to_embedding_indices(indices):
embedding_indices = []
count = 0
for value in indices:
if value == -1:
count = 0
else:
embedding_indices.append([value, count])
count += 1
return torch.tensor(embedding_indices)
def convert_mapping(
mapping: PromptAdapterMapping,
prompt_adapter_index_to_id: List[Optional[int]],
) -> torch.Tensor:
"""Converts PromptAdapterMapping to index tensors.
Args:
mapping: PromptAdapterMapping mapping rows in a
batch to PromptAdapter ids.
prompt_adapter_index_to_id: List mapping PromptAdapter
ids to PromptAdapter indices.
Returns:
pa_indices: Tensor of shape [batch_size] mapping batch rows to
PromptAdapter indices.
"""
id_to_index = {
id_: idx
for idx, id_ in enumerate(prompt_adapter_index_to_id)
if id_ is not None
}
pa_indices = ([
id_to_index.get(id_, -1) if id_ > 0 else -1
for id_ in mapping.index_mapping
])
pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
pa_indices = torch.tensor(pa_indices)
return pa_indices, pa_embedding_mapping
class PromptAdapterModel(AdapterModel):
def __init__(self,
prompt_adapter_id=None,
num_virtual_tokens=None,
prompt_embedding=None) -> None:
self.id = prompt_adapter_id
self.prompt_embedding = prompt_embedding
self.num_virtual_tokens = num_virtual_tokens
@classmethod
def from_local_checkpoint(
cls,
adapter_model_path: str,
prompt_adapter_id: int,
num_virtual_tokens: int,
config: PromptAdapterConfig,
device: str = "cuda",
) -> "PromptAdapterModel":
if num_virtual_tokens > config.max_prompt_adapter_token:
raise ValueError(
f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
f'max_prompt_adapter_token({config.max_prompt_adapter_token})')
adapters_weights = load_peft_weights(adapter_model_path, device)
prompt_embedding = adapters_weights["prompt_embeddings"].to(
config.prompt_adapter_dtype)
return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
class PromptAdapterModelManager(AdapterModelManager):
"""A manager that manages multiple Prompt Adapter models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
"""Create a PromptAdapterModel and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
prompt_adapter_config: the PromptAdapter config,
"""
self.model: nn.Module = model
# Dict instead of a Set for compatibility with LRUCache.
self.prompt_adapter_index_to_id: List[
Optional[int]] = [None] * self.prompt_adapter_slots
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.prompt_adapter_config = prompt_adapter_config
self.model.prompt_adapter_manager = self
self.adapter_type = 'PromptAdapter'
self.base_indices = torch.tensor([-1])
self.base_embedding_indices = torch.tensor([])
self.modules: Dict[str, nn.Module] = {}
self._create_prompt_adapter_modules()
self._last_mapping: Optional[PromptAdapterMapping] = None
@property
def prompt_adapter_slots(self) -> int:
return self.prompt_adapter_config.max_prompt_adapters
@property
def adapter_slots(self) -> int:
return self.prompt_adapter_slots
@property
def capacity(self) -> int:
return self.prompt_adapter_config.max_cpu_prompt_adapters
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
"""Move PromptAdapter into a GPU buffer
to be used in the forward pass."""
if prompt_adapter_id in self._active_adapters:
return False
first_free_slot = next(
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
None)
if first_free_slot is None:
raise ValueError("No free prompt_adapter slots")
index, _ = first_free_slot
self._active_adapters[prompt_adapter_id] = None
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
prompt_adapter_model.id, index)
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
for _, v in self.modules.items():
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
return True
def _deactivate_adapter(self, prompt_adapter_id: int):
try:
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
self.prompt_adapter_index_to_id[index] = None
for _, v in self.modules.items():
v.reset_prompt_adapter(index)
except ValueError:
pass
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
self._registered_adapters[prompt_adapter.id] = prompt_adapter
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
base_indices, base_embedding_indices = convert_mapping(
mapping, self.prompt_adapter_index_to_id)
for k, v in self.modules.items():
v.set_mapping(base_indices, base_embedding_indices)
def _create_prompt_adapter_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if "VocabParallel" in module.__class__.__name__:
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
new_module.create_prompt_adapter_weights(
self.prompt_adapter_config)
replaced_module = self.replace_submodule(
self.model, module_name, new_module)
self.register_module(module.__class__.__name__,
replaced_module)
replaced_module.set_mapping(self.base_indices,
self.base_embedding_indices)
break
def replace_submodule(self, model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def register_module(self, module_name: str, module: nn.Module):
self.modules[module_name] = module
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in PromptAdapterModelManager. "
"Use LRUCachePromptAdapterModelManager for pinning"
) # type: ignore
def remove_all_adapters(self):
"""Remove all PromptAdapterModel from the manager."""
self._registered_adapters.clear()
self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
self._active_adapters.clear()
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
def add_adapter(self, adapter: PromptAdapterModel) -> bool:
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
def __init__(self, capacity: int,
deactivate_prompt_adapter_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
"""A model manager that manages multiple prompt_adapters with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
self.prompt_adapter_config = prompt_adapter_config
super().__init__(model, max_num_seqs, max_num_batched_tokens,
prompt_adapter_config)
self._registered_adapters = PromptAdapterLRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters = PromptAdapterLRUCache(
self.prompt_adapter_slots, self._deactivate_adapter)
def list_adapters(self) -> Dict[int, PromptAdapterModel]:
"""List all registered PromptAdapterModel."""
return dict(self._registered_adapters.cache)
def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
"""Add a PromptAdapterModel to the manager."""
if prompt_adapter.id not in self._registered_adapters:
self._add_adapter(prompt_adapter)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_adapters.touch(prompt_adapter.id)
was_added = False
return was_added
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
if prompt_adapter_id not in self._active_adapters and len(
self._active_adapters) >= self.prompt_adapter_slots:
self._active_adapters.remove_oldest()
result = super().activate_adapter(prompt_adapter_id)
# We always touch to update the LRU cache order
self._active_adapters.touch(prompt_adapter_id)
return result
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
return True
def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
try:
self._registered_adapters.pin(prompt_adapter_id)
except ValueError as err:
raise ValueError(
"Pinning failed. "
f"Prompt Adapter {prompt_adapter_id} is not registered."
) from err
def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
if prompt_adapter_id not in self._active_adapters:
# move adapter to gpu if not already active
self.activate_adapter(prompt_adapter_id)
self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_manager_cls: Type[
PromptAdapterModelManager] = PromptAdapterModelManager,
**kwargs) -> PromptAdapterModelManager:
"""Create a PromptAdapterModel for a given model."""
prompt_adapter_manager = prompt_adapter_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
prompt_adapter_config=prompt_adapter_config,
**kwargs)
return prompt_adapter_manager
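The mapping logic above is easiest to see on a concrete input; the following standalone snippet reproduces the removed convert_to_embedding_indices helper and shows its output for a batch whose first three rows use the adapter in slot 0.
# Standalone illustration of how the removed convert_mapping logic turned a
# per-row adapter-slot mapping into (slot, position) gather indices.
import torch
def convert_to_embedding_indices(indices):
    # Same logic as the removed helper: count positions within each adapter
    # run, resetting whenever a row is not covered by an adapter (-1).
    embedding_indices = []
    count = 0
    for value in indices:
        if value == -1:
            count = 0
        else:
            embedding_indices.append([value, count])
            count += 1
    return torch.tensor(embedding_indices)
pa_indices = [0, 0, 0, -1, -1]   # rows 0-2 use slot 0, rows 3-4 use no adapter
print(convert_to_embedding_indices(pa_indices))
# tensor([[0, 0],
#         [0, 1],
#         [0, 2]])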

View File

@@ -1,37 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
from vllm.adapter_commons.request import AdapterRequest
class PromptAdapterRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
frozen=True): # type: ignore[call-arg]
"""
Request for a Prompt adapter.
"""
__metaclass__ = AdapterRequest
prompt_adapter_name: str
prompt_adapter_id: int
prompt_adapter_local_path: str
prompt_adapter_num_virtual_tokens: int
def __hash__(self):
return super().__hash__()
@property
def adapter_id(self):
return self.prompt_adapter_id
@property
def name(self):
return self.prompt_adapter_name
@property
def local_path(self):
return self.prompt_adapter_local_path
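The request type above followed a common msgspec pattern; a hypothetical stand-in (field names are illustrative, not the removed class) shows how such a frozen, array-encoded struct round-trips through msgpack.
# Hypothetical stand-in for the removed request type, showing the msgspec
# pattern it used: a frozen struct encoded as a compact array.
import msgspec
class SoftPromptRequest(msgspec.Struct,
                        array_like=True,
                        omit_defaults=True,
                        frozen=True):
    name: str
    adapter_id: int
    local_path: str
    num_virtual_tokens: int
req = SoftPromptRequest("demo", 1, "/tmp/adapter", 8)
packed = msgspec.msgpack.encode(req)   # encoded as [name, id, path, tokens]
print(msgspec.msgpack.decode(packed, type=SoftPromptRequest) == req)  # True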

View File

@@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
import os
from typing import Optional
import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file
from vllm.platforms import current_platform
WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
# Get current device name based on available devices
def infer_device() -> str:
if current_platform.is_cuda_alike():
return "cuda"
return "cpu"
def load_peft_weights(model_id: str,
device: Optional[str] = None,
**hf_hub_download_kwargs) -> dict:
r"""
A helper method to load the PEFT weights from the HuggingFace Hub or locally
Args:
model_id (`str`):
The local path to the adapter weights or the name of the adapter to
load from the HuggingFace Hub.
device (`str`):
The device to load the weights onto.
hf_hub_download_kwargs (`dict`):
Additional arguments to pass to the `hf_hub_download` method when
loading from the HuggingFace Hub.
"""
path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) if
hf_hub_download_kwargs.get("subfolder") is not None else model_id)
if device is None:
device = infer_device()
if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
use_safetensors = True
elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
filename = os.path.join(path, WEIGHTS_NAME)
use_safetensors = False
else:
token = hf_hub_download_kwargs.get("token")
if token is None:
token = hf_hub_download_kwargs.get("use_auth_token")
hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
SAFETENSORS_WEIGHTS_NAME)
if hf_hub_download_kwargs.get("subfolder") is not None
else SAFETENSORS_WEIGHTS_NAME)
has_remote_safetensors_file = file_exists(
repo_id=model_id,
filename=hub_filename,
revision=hf_hub_download_kwargs.get("revision"),
repo_type=hf_hub_download_kwargs.get("repo_type"),
token=token,
)
use_safetensors = has_remote_safetensors_file
if has_remote_safetensors_file:
# Priority 1: load safetensors weights
filename = hf_hub_download(
model_id,
SAFETENSORS_WEIGHTS_NAME,
**hf_hub_download_kwargs,
)
else:
try:
filename = hf_hub_download(model_id, WEIGHTS_NAME,
**hf_hub_download_kwargs)
except EntryNotFoundError:
raise ValueError( # noqa: B904
f"Can't find weights for {model_id} in {model_id} or \
in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} or \
{SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")
if use_safetensors:
adapters_weights = safe_load_file(filename, device=device)
else:
adapters_weights = torch.load(filename,
map_location=torch.device(device),
weights_only=True)
return adapters_weights
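Stripped of the Hub fallback, the local resolution order of this helper amounts to the following sketch (simplified, assuming the same adapter_model.safetensors / adapter_model.bin file names).
# Simplified sketch of the local file-resolution order used by the removed
# load_peft_weights helper: prefer safetensors, fall back to the torch .bin.
import os
import torch
from safetensors.torch import load_file as safe_load_file
def load_local_adapter_weights(path: str, device: str = "cpu") -> dict:
    st_file = os.path.join(path, "adapter_model.safetensors")
    bin_file = os.path.join(path, "adapter_model.bin")
    if os.path.exists(st_file):
        return safe_load_file(st_file, device=device)
    if os.path.exists(bin_file):
        return torch.load(bin_file,
                          map_location=torch.device(device),
                          weights_only=True)
    raise FileNotFoundError(f"No adapter weights found under {path}")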

View File

@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from typing import Any, Optional, Set, Type
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
PromptAdapterModel,
PromptAdapterModelManager,
create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest
logger = logging.getLogger(__name__)
class WorkerPromptAdapterManager(AbstractWorkerManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
Every request, the requested prompt_adapters will be
loaded (unless they are already loaded),
and every other prompt_adapter will be unloaded."""
_manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
device: torch.device,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
):
self._adapter_manager: PromptAdapterModelManager
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self._prompt_adapter_model_cls = prompt_adapter_model_cls
self.prompt_adapter_config = prompt_adapter_config
super().__init__(device)
@property
def is_enabled(self) -> bool:
return True
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._manager_cls,
)
self._adapter_manager = prompt_adapter_manager
return prompt_adapter_manager.model
def _load_adapter(
self, prompt_adapter_request: PromptAdapterRequest
) -> PromptAdapterModel:
try:
prompt_adapter = (
self._prompt_adapter_model_cls.from_local_checkpoint(
prompt_adapter_request.prompt_adapter_local_path,
prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
num_virtual_tokens=prompt_adapter_request.
prompt_adapter_num_virtual_tokens,
config=self.prompt_adapter_config,
device=str(self.device),
))
except Exception as e:
raise RuntimeError(
f"Loading prompt_adapter "
f"{prompt_adapter_request.prompt_adapter_local_path}"
f" failed") from e
return prompt_adapter
def add_dummy_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return True
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
Uses an LRU Cache. Every request, the requested
prompt_adapters will be loaded (unless they are already loaded)
and least recently used prompt_adapters will
be unloaded if the cache is above capacity."""
_prompt_adapter_manager_cls: Type[
LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
self._adapter_manager: LRUCachePromptAdapterModelManager = (
prompt_adapter_manager)
return prompt_adapter_manager.model
def _apply_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
prompt_adapters_map = {
prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
for prompt_adapter_request in prompt_adapter_requests
if prompt_adapter_request
}
if len(prompt_adapters_map
) > self._adapter_manager.prompt_adapter_slots:
raise RuntimeError(
f"Number of requested prompt_adapters "
f"({len(prompt_adapters_map)}) is greater "
"than the number of GPU prompt_adapter slots "
f"({self._adapter_manager.prompt_adapter_slots}).")
for prompt_adapter in prompt_adapters_map.values():
self.add_adapter(prompt_adapter)
def add_adapter(self,
prompt_adapter_request: PromptAdapterRequest) -> bool:
if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
):
# Remove before we load the new prompt_adapter to save memory
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
self._adapter_manager.remove_oldest_adapter()
prompt_adapter = self._load_adapter(prompt_adapter_request)
loaded = self._adapter_manager.add_adapter(prompt_adapter)
else:
# If the prompt_adapter is already loaded, just touch it to
# update its position in the caches
loaded = self._adapter_manager.get_adapter(
prompt_adapter_request.prompt_adapter_id) is not None
self._adapter_manager.activate_adapter(
prompt_adapter_request.prompt_adapter_id)
return loaded
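The capacity handling above reduces to a small LRU policy; a toy sketch with a plain OrderedDict (not the vLLM AdapterLRUCache) illustrates the add/evict/touch behaviour.
# Toy sketch of the capacity handling the removed LRU worker manager relied on:
# evict the least recently used adapter before loading a new one when full;
# if the adapter is already cached, just refresh its position.
from collections import OrderedDict
class TinyAdapterCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache: OrderedDict[int, object] = OrderedDict()
    def add(self, adapter_id: int, weights: object) -> bool:
        if adapter_id in self.cache:
            self.cache.move_to_end(adapter_id)   # touch: already loaded
            return False
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)       # remove_oldest_adapter()
        self.cache[adapter_id] = weights         # _load_adapter() + add_adapter()
        return True
cache = TinyAdapterCache(capacity=2)
print(cache.add(1, "w1"), cache.add(2, "w2"), cache.add(3, "w3"))  # True True True
print(list(cache.cache))  # [2, 3] -> adapter 1 was evicted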

View File

@@ -19,7 +19,6 @@ from vllm.inputs import SingletonInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -458,7 +457,6 @@ class Sequence:
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request. lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
""" """
def __init__( def __init__(
@@ -468,14 +466,12 @@ class Sequence:
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.inputs = inputs self.inputs = inputs
self.block_size = block_size self.block_size = block_size
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData.from_seqs( self.data = SequenceData.from_seqs(
self.prompt_token_ids, self.prompt_token_ids,
@@ -537,11 +533,6 @@ class Sequence:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int, def get_output_text_to_return(self, buffer_length: int,
delta: bool) -> str: delta: bool) -> str:
"""If delta is True, only new text since the last call to """If delta is True, only new text since the last call to
@@ -601,12 +592,12 @@ class Sequence:
designed for prefix caching mode. The final sequence hash is determined designed for prefix caching mode. The final sequence hash is determined
by applying token_ids from the sequence's blocks. by applying token_ids from the sequence's blocks.
""" """
if self.prompt_adapter_id == 0 and self.lora_int_id == 0: if self.lora_int_id == 0:
return None return None
# NOTE: If there are additional factors influencing the block aside from # NOTE: If there are additional factors influencing the block aside from
# token_ids, include them as input parameters to the hash. # token_ids, include them as input parameters to the hash.
return hash((self.prompt_adapter_id, self.lora_int_id)) return hash(self.lora_int_id)
def num_hashed_tokens_of_block(self, logical_idx: int): def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size return logical_idx * self.block_size + self.block_size
@@ -707,7 +698,6 @@ class SequenceGroup:
encoder_seq: Optional, the single encoder sequence. Should be None encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model. unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers. trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request. priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate model; equal to max number of tokens a step can generate
@@ -725,7 +715,6 @@ class SequenceGroup:
pooled_data: Optional[torch.Tensor] = None, pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
draft_size: int = 1) -> None: draft_size: int = 1) -> None:
self.request_id = request_id self.request_id = request_id
@@ -747,7 +736,6 @@ class SequenceGroup:
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.pooling_params = pooling_params self.pooling_params = pooling_params
self.pooled_data = pooled_data self.pooled_data = pooled_data
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.priority = priority self.priority = priority
@@ -802,16 +790,6 @@ class SequenceGroup:
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def init_multi_step(self, num_steps: int) -> None: def init_multi_step(self, num_steps: int) -> None:
self.state.num_steps = num_steps self.state.num_steps = num_steps
self.state.current_step = 0 self.state.current_step = 0
@@ -1011,7 +989,6 @@ class SequenceGroupMetadata(
(SequenceGroup.encoder_seq). Should be None (SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder unless you are working with an encoder/decoder
model. model.
prompt_adapter_request: Prompt Adapter request.
""" """
request_id: str request_id: str
@@ -1030,7 +1007,6 @@ class SequenceGroupMetadata(
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None token_chunk_size: Optional[int] = None
### Stateful fields that are lazily defined. ### ### Stateful fields that are lazily defined. ###
@@ -1052,16 +1028,6 @@ class SequenceGroupMetadata(
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
# Multi-Step Chunked-Prefill property # Multi-Step Chunked-Prefill property
@property @property
def is_single_step_prompt(self) -> bool: def is_single_step_prompt(self) -> bool:
@@ -1525,7 +1491,6 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
pooled_data=seq_group.pooled_data, pooled_data=seq_group.pooled_data,
encoder_seq=seq_group.encoder_seq, encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers, trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request,
priority=seq_group.priority, priority=seq_group.priority,
) )

View File

@@ -128,10 +128,6 @@ STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
"backends currently supported with encoder/" "backends currently supported with encoder/"
"decoder models.") "decoder models.")
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
# Efficiently import all enc/dec error strings # Efficiently import all enc/dec error strings
# rather than having to import all of the above # rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = { STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
@@ -145,7 +141,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
} }
# Constants related to forcing the attention backend selection # Constants related to forcing the attention backend selection

View File

@@ -20,7 +20,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
@@ -221,7 +220,6 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector: ) -> RequestOutputCollector:
@@ -238,8 +236,7 @@ class AsyncLLM(EngineClient):
# Convert Input --> Request. # Convert Input --> Request.
prompt_str, request = self.processor.process_inputs( prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request, request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request, tokenization_kwargs, trace_headers, priority, data_parallel_rank)
priority, data_parallel_rank)
if is_pooling or params.n == 1: if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue) await self._add_request(request, prompt_str, None, 0, queue)
@@ -283,7 +280,6 @@ class AsyncLLM(EngineClient):
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -314,7 +310,6 @@ class AsyncLLM(EngineClient):
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )

View File

@@ -17,7 +17,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs) TokenizerGroup, init_tokenizer_from_configs)
@@ -192,7 +191,6 @@ class LLMEngine:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
# Validate the request_id type. # Validate the request_id type.
@@ -203,8 +201,7 @@ class LLMEngine:
# Process raw inputs into the request. # Process raw inputs into the request.
prompt_str, request = self.processor.process_inputs( prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request, request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request, tokenization_kwargs, trace_headers, priority)
priority)
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1

View File

@@ -16,7 +16,6 @@ from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
@@ -226,7 +225,6 @@ class Processor:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: Optional[int] = None, data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]: ) -> tuple[Optional[str], EngineCoreRequest]:
@@ -237,8 +235,6 @@ class Processor:
self._validate_params(params, lora_request) self._validate_params(params, lora_request)
if trace_headers is not None: if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.") raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank < if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@@ -253,12 +249,10 @@ class Processor:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly. # multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash, return_mm_hashes=self.use_hash,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform

View File

@@ -318,8 +318,6 @@ def report_usage_stats(
# Feature flags # Feature flags
"enable_lora": "enable_lora":
bool(vllm_config.lora_config), bool(vllm_config.lora_config),
"enable_prompt_adapter":
bool(vllm_config.prompt_adapter_config),
"enable_prefix_caching": "enable_prefix_caching":
vllm_config.cache_config.enable_prefix_caching, vllm_config.cache_config.enable_prefix_caching,
"enforce_eager": "enforce_eager":

View File

@@ -104,7 +104,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes

View File

@@ -114,7 +114,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.original_parallel_config = original_parallel_config self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config

View File

@@ -62,7 +62,6 @@ class TPUWorker:
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank self.parallel_config.rank = rank

View File

@@ -91,10 +91,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
''' '''
EncoderDecoderModelRunner constructor. EncoderDecoderModelRunner constructor.
`lora_config` and `prompt_adapter_config` are `lora_config` is unused (since these features are not yet supported
unused (since these features are not yet supported for encoder/decoder for encoder/decoder models) but these arguments are present here for
models) but these arguments are present here for compatibility with compatibility with the base-class constructor.
the base-class constructor.
''' '''
self._maybe_force_supported_attention_backend() self._maybe_force_supported_attention_backend()

View File

@@ -45,10 +45,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap, MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry) MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
@@ -95,8 +91,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
@@ -113,8 +107,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
@@ -164,8 +156,6 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
@@ -212,8 +202,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping.clear() # type: ignore self.lora_index_mapping.clear() # type: ignore
self.lora_prompt_mapping.clear() # type: ignore self.lora_prompt_mapping.clear() # type: ignore
self.lora_requests.clear() # type: ignore self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore
def __init__( def __init__(
self, self,
@@ -252,11 +240,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
lora_prompt_mapping: Optional[List[List[int]]] = None, lora_prompt_mapping: Optional[List[List[int]]] = None,
lora_requests: Optional[Set[LoRARequest]] = None, lora_requests: Optional[Set[LoRARequest]] = None,
# Prompt adapter inputs.
prompt_adapter_index_mapping: Optional[List[int]] = None,
prompt_adapter_prompt_mapping: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs. # Multi-modal inputs.
multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[ multi_modal_placeholder_maps: Optional[Dict[
@@ -360,18 +343,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.lora_requests.clear() self.lora_requests.clear()
if prompt_adapter_index_mapping:
self.prompt_adapter_index_mapping = \
prompt_adapter_index_mapping
else:
self.prompt_adapter_index_mapping.clear()
if prompt_adapter_prompt_mapping:
self.prompt_adapter_prompt_mapping = \
prompt_adapter_prompt_mapping
else:
self.prompt_adapter_prompt_mapping.clear()
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds self.inputs_embeds = inputs_embeds
@@ -390,12 +361,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_prompt_mapping = lora_prompt_mapping or [] self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set() self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
@@ -485,7 +450,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute functions for each sequence group. # Compute functions for each sequence group.
# WARNING: The order of the functions matters! # WARNING: The order of the functions matters!
self.per_seq_group_compute_fns = [ self.per_seq_group_compute_fns = [
self._compute_prompt_adapter_input,
self._compute_multi_modal_input, self._compute_multi_modal_input,
] ]
@@ -496,8 +460,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window = self.runner.sliding_window self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
# Attention metadata inputs. # Attention metadata inputs.
if self.attn_backend is not None: if self.attn_backend is not None:
@@ -693,34 +655,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
inter_data.lora_prompt_mapping.append([]) inter_data.lora_prompt_mapping.append([])
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if not self.enable_prompt_adapter:
return
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id <= 0 or not inter_data.is_prompt:
return
# We expect only one sequence in the group when is_prompt=True.
assert inter_data.n_seqs == 1
query_len = inter_data.query_lens[0]
inter_data.prompt_adapter_request = (
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
inter_data.prompt_adapter_index_mapping = [
prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1)
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input.""" """If multi-modal data is given, add it to the input."""
@@ -1009,29 +943,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_mapping=lora_prompt_mapping, prompt_mapping=lora_prompt_mapping,
is_prefill=not self.decode_only)) is_prefill=not self.decode_only))
# Prompt adapter data.
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
prompt_adapter_mapping = None
if self.enable_prompt_adapter:
prompt_adapter_requests = set(
data.prompt_adapter_request for data in self.inter_data_list
if data.prompt_adapter_request is not None)
prompt_adapter_index_mapping = flatten_2d_lists([
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
if cuda_graph_pad_size:
prompt_adapter_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
# Multi-modal data. # Multi-modal data.
multi_modal_kwargs_list = [ multi_modal_kwargs_list = [
data.multi_modal_kwargs for data in self.inter_data_list data.multi_modal_kwargs for data in self.inter_data_list
@@ -1051,9 +962,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
lora_requests=lora_requests, lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids, request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids, finished_requests_ids=self.finished_requests_ids)
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
@@ -1148,7 +1057,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
# Set after load_model. # Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.sampler = get_sampler() self.sampler = get_sampler()
set_cpu_offload_max_bytes( set_cpu_offload_max_bytes(
@@ -1207,14 +1115,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
logger.info("Model loading took %.4f GiB and %.6f seconds", logger.info("Model loading took %.4f GiB and %.6f seconds",
self.model_memory_usage / GiB_bytes, self.model_memory_usage / GiB_bytes,
time_after_load - time_before_load) time_after_load - time_before_load)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.vllm_config.compilation_config.level ==\ if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
@@ -1466,40 +1367,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters() return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.remove_all_adapters()
def set_active_prompt_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest],
prompt_adapter_mapping: PromptAdapterMapping) -> None:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.set_active_adapters(
prompt_adapter_requests, prompt_adapter_mapping)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model. """Cuda graph capture a model.
@@ -1609,13 +1476,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.set_active_loras(set([dummy_lora_request]), self.set_active_loras(set([dummy_lora_request]),
lora_mapping) lora_mapping)
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
[-1] * batch_size,
[-1] * batch_size,
)
self.set_active_prompt_adapters(
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner( graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name(), self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size), self.attn_state.graph_clone(batch_size),
@@ -1776,13 +1636,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input) self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.

View File

@@ -190,7 +190,6 @@ class ModelRunnerBase(ABC, Generic[T]):
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling # Map of request_id -> generator used for seeded random sampling

View File

@@ -288,9 +288,6 @@ class StatefulModelInput(BroadcastableModelInput):
        assert fmi.lora_requests is not None
        assert len(fmi.lora_requests) == 0
        assert fmi.attn_metadata is not None
        assert fmi.prompt_adapter_mapping is None
        assert fmi.prompt_adapter_requests is not None
        assert len(fmi.prompt_adapter_requests) == 0
        assert fmi.multi_modal_kwargs is not None
        assert len(fmi.multi_modal_kwargs) == 0
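The deleted assertions enforced that the frozen model input used by multi-step execution carried no prompt-adapter state (an empty request set and a None mapping); only the LoRA, attention-metadata, and multimodal invariants remain. A standalone sketch of this kind of invariant check, using a stub dataclass:

```python
# Runnable sketch of "frozen model input carries empty adapter state" checks.
# FrozenInputStub is an illustrative stand-in, not vLLM's frozen model input.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set


@dataclass
class FrozenInputStub:
    lora_requests: Set[int] = field(default_factory=set)
    multi_modal_kwargs: Dict[str, Any] = field(default_factory=dict)
    attn_metadata: Optional[Any] = field(default_factory=object)


def check_frozen_input(fmi: FrozenInputStub) -> None:
    assert fmi.lora_requests is not None and len(fmi.lora_requests) == 0
    assert fmi.attn_metadata is not None
    assert fmi.multi_modal_kwargs is not None
    assert len(fmi.multi_modal_kwargs) == 0


check_frozen_input(FrozenInputStub())  # passes for the default (empty) state
```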


@@ -64,13 +64,6 @@ class PoolingModelRunner(
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata


@@ -47,7 +47,3 @@ def assert_enc_dec_mr_supported_scenario(
    if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])

    if enc_dec_mr.prompt_adapter_config is not None:
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
            'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
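For reference, assert_enc_dec_mr_supported_scenario raises NotImplementedError for each unsupported feature using messages looked up from a shared string table; with this commit, the prompt-adapter entry is no longer checked. A standalone sketch of that guard style (the table contents below are stand-ins, not vLLM's actual error strings):

```python
# Runnable sketch of the "unsupported feature" guard style: raise
# NotImplementedError keyed into a shared error-string table. The message
# below is illustrative only.
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC":
    "Speculative decoding is not currently supported with encoder/decoder "
    "models.",
}


def assert_supported(num_lookahead_slots: int) -> None:
    if num_lookahead_slots > 0:
        raise NotImplementedError(
            STR_NOT_IMPL_ENC_DEC_ERR_STRS["STR_NOT_IMPL_ENC_DEC_SPEC_DEC"])


assert_supported(0)  # passes; a positive value would raise
```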


@@ -22,7 +22,6 @@ from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
@@ -513,19 +512,6 @@ class Worker(LocalOrDistributedWorkerBase):
    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.remove_lora(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len


@@ -49,7 +49,6 @@ class WorkerBase:
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config
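After this hunk, WorkerBase.__init__ no longer copies a prompt_adapter_config field off the aggregate config. A standalone sketch of the unpacking that remains, with a simplified stand-in for vllm.config.VllmConfig:

```python
# Runnable sketch of the config unpacking left in WorkerBase.__init__ after
# this change; VllmConfigStub is a simplified stand-in, not the real
# vllm.config.VllmConfig.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class VllmConfigStub:
    scheduler_config: Any = None
    device_config: Any = None
    speculative_config: Optional[Any] = None
    observability_config: Optional[Any] = None
    kv_transfer_config: Optional[Any] = None
    compilation_config: Optional[Any] = None


class WorkerBaseSketch:
    def __init__(self, vllm_config: VllmConfigStub) -> None:
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        # prompt_adapter_config is intentionally no longer copied here.
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config


worker = WorkerBaseSketch(VllmConfigStub())
```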