[Core] Add Lora Support to Beam Search (#18346)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -313,3 +313,37 @@ async def test_loading_invalid_adapters_does_not_break_others(
|
|||||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_beam_search_with_lora_adapters(
    client: openai.AsyncOpenAI,
    tmp_path,
    zephyr_lora_files,
):
    """Validate that async beam search can be used with lora.

    Three adapters are loaded concurrently from the same LoRA files and each
    one serves several beam search completion requests. The test fails if any
    of the concurrent tasks raised.
    """

    async def load_and_run_adapter(adapter_name: str):
        # Register the adapter under its own name, then query it repeatedly
        # with beam search enabled.
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": adapter_name,
                              "lora_path": str(zephyr_lora_files)
                          })
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
                extra_body=dict(use_beam_search=True),
            )

    lora_tasks = [
        asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))
        for i in range(3)
    ]

    # asyncio.wait() hands back Task objects, which are never Exception
    # instances, so checking them with isinstance() cannot fail. gather()
    # with return_exceptions=True waits for every task and returns either
    # the task's result or the exception it raised.
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, BaseException), f"Got exception {r}"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import vllm
|
|||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.sampling_params import BeamSearchParams
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=not current_platform.is_cpu())
|
@pytest.fixture(autouse=not current_platform.is_cpu())
|
||||||
@@ -69,7 +70,7 @@ class Qwen2VLTester:
|
|||||||
expected_outputs: list[str],
|
expected_outputs: list[str],
|
||||||
lora_id: Optional[int] = None,
|
lora_id: Optional[int] = None,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
max_tokens: int = 5) -> list[str]:
|
max_tokens: int = 5):
|
||||||
|
|
||||||
sampling_params = vllm.SamplingParams(
|
sampling_params = vllm.SamplingParams(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -97,7 +98,35 @@ class Qwen2VLTester:
|
|||||||
generated), f"Generated text {generated} doesn't "
|
generated), f"Generated text {generated} doesn't "
|
||||||
f"match expected pattern {expected}"
|
f"match expected pattern {expected}"
|
||||||
|
|
||||||
return generated_texts
|
def run_beam_search_test(self,
                         images: list[ImageAsset],
                         expected_outputs: list[list[str]],
                         lora_id: Optional[int] = None,
                         temperature: float = 0,
                         beam_width: int = 2,
                         max_tokens: int = 5):
    """Run beam search over the given images and compare the beam texts.

    Args:
        images: Image assets; one prompt is built per asset.
        expected_outputs: For each prompt, the expected texts of its beams,
            in the order the sequences are returned.
        lora_id: Integer id of the LoRA adapter to apply. When None, no
            LoRA request is passed to beam search.
        temperature: Temperature forwarded to the beam search parameters.
        beam_width: Beam width forwarded to the beam search parameters.
        max_tokens: Max tokens forwarded to the beam search parameters.
    """
    beam_search_params = BeamSearchParams(beam_width=beam_width,
                                          max_tokens=max_tokens,
                                          temperature=temperature)

    inputs = [{
        "prompt": self.PROMPT_TEMPLATE,
        "multi_modal_data": {
            "image": asset.pil_image
        },
    } for asset in images]

    # Only build a LoRA request when an id was actually supplied; the
    # default of None would otherwise yield a request named "None" with a
    # None integer id.
    lora_request = None
    if lora_id is not None:
        lora_request = LoRARequest(str(lora_id), lora_id,
                                   self.config.lora_path)

    outputs = self.llm.beam_search(inputs,
                                   beam_search_params,
                                   lora_request=lora_request)

    # zip() stops at the shorter input, so verify the lengths first to
    # avoid silently skipping expectations.
    assert len(outputs) == len(expected_outputs), \
        f"Got {len(outputs)} outputs for {len(expected_outputs)} expected entries"  # noqa: E501

    for output_obj, expected_outs in zip(outputs, expected_outputs):
        output_texts = [seq.text for seq in output_obj.sequences]
        assert output_texts == expected_outs, \
            f"Generated texts {output_texts} do not match expected {expected_outs}"  # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
TEST_IMAGES = [
|
TEST_IMAGES = [
|
||||||
@@ -110,6 +139,14 @@ EXPECTED_OUTPUTS = [
|
|||||||
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
|
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# NOTE - beam search .text contains the whole text, so every expected beam
# starts with the same prompt text and only differs in its continuation.
_BEAM_SEARCH_PROMPT_PREFIX = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")

# One inner list per prompt; each inner list holds the expected beam texts.
EXPECTED_BEAM_SEARCH_OUTPUTS = [
    [
        _BEAM_SEARCH_PROMPT_PREFIX + "A majestic skyscraper stands",
        _BEAM_SEARCH_PROMPT_PREFIX + "A majestic tower stands tall",
    ],
]
|
||||||
|
|
||||||
QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
|
QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
|
||||||
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
|
QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
|
||||||
@@ -130,6 +167,27 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
|
|||||||
lora_id=lora_id)
|
lora_id=lora_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="Qwen2-VL dependency xformers incompatible with ROCm")
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
    """Run beam search on the Qwen 2.0 VL model with a LoRA adapter."""
    tester = Qwen2VLTester(
        TestConfig(model_path=QWEN2VL_MODEL_PATH,
                   lora_path=qwen2vl_lora_files))

    # Exercise more than one LoRA ID.
    for lora_id in (1, 2):
        # NOTE: only the cherry blossom image is checked for now, because
        # the stop sign output is slightly different for v1. The root cause
        # is likely unrelated to the purpose of this test, which is to
        # ensure that beam search passes the LoRA request through correctly.
        tester.run_beam_search_test(
            images=[ImageAsset("cherry_blossom")],
            expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
            lora_id=lora_id,
        )
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(
|
@pytest.mark.xfail(
|
||||||
current_platform.is_rocm(),
|
current_platform.is_rocm(),
|
||||||
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
|
reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -19,6 +20,7 @@ class BeamSearchSequence:
|
|||||||
# The tokens includes the prompt.
|
# The tokens includes the prompt.
|
||||||
tokens: list[int]
|
tokens: list[int]
|
||||||
logprobs: list[dict[int, Logprob]]
|
logprobs: list[dict[int, Logprob]]
|
||||||
|
lora_request: Optional[LoRARequest] = None
|
||||||
cum_logprob: float = 0.0
|
cum_logprob: float = 0.0
|
||||||
text: Optional[str] = None
|
text: Optional[str] = None
|
||||||
finish_reason: Optional[str] = None
|
finish_reason: Optional[str] = None
|
||||||
@@ -41,6 +43,7 @@ class BeamSearchInstance:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompt_tokens: list[int],
|
prompt_tokens: list[int],
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
logprobs: Optional[list[dict[int, Logprob]]] = None,
|
logprobs: Optional[list[dict[int, Logprob]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -48,6 +51,7 @@ class BeamSearchInstance:
|
|||||||
BeamSearchSequence(
|
BeamSearchSequence(
|
||||||
tokens=prompt_tokens,
|
tokens=prompt_tokens,
|
||||||
logprobs=[] if logprobs is None else list(logprobs),
|
logprobs=[] if logprobs is None else list(logprobs),
|
||||||
|
lora_request=lora_request,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class EngineClient(ABC):
|
|||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
params: BeamSearchParams,
|
params: BeamSearchParams,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
|
|
||||||
beam_width = params.beam_width
|
beam_width = params.beam_width
|
||||||
@@ -106,27 +107,31 @@ class EngineClient(ABC):
|
|||||||
cum_logprob=0,
|
cum_logprob=0,
|
||||||
logprobs=[],
|
logprobs=[],
|
||||||
multi_modal_data=multi_modal_data,
|
multi_modal_data=multi_modal_data,
|
||||||
mm_processor_kwargs=mm_processor_kwargs)
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
|
lora_request=lora_request)
|
||||||
]
|
]
|
||||||
completed = []
|
completed = []
|
||||||
|
|
||||||
for _ in range(max_tokens):
|
for _ in range(max_tokens):
|
||||||
prompts_batch = [
|
prompts_batch, lora_req_batch = zip(*[(
|
||||||
TokensPrompt(prompt_token_ids=beam.tokens,
|
TokensPrompt(prompt_token_ids=beam.tokens,
|
||||||
multi_modal_data=beam.multi_modal_data,
|
multi_modal_data=beam.multi_modal_data,
|
||||||
mm_processor_kwargs=beam.mm_processor_kwargs)
|
mm_processor_kwargs=beam.mm_processor_kwargs),
|
||||||
for beam in all_beams
|
beam.lora_request,
|
||||||
]
|
) for beam in all_beams])
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
request_id = f"beam_search-{random_uuid()}"
|
request_id = f"beam_search-{random_uuid()}"
|
||||||
for i, individual_prompt in enumerate(prompts_batch):
|
for i, (individual_prompt,
|
||||||
|
lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
collect_from_async_generator(
|
collect_from_async_generator(
|
||||||
self.generate(individual_prompt, beam_search_params,
|
self.generate(individual_prompt,
|
||||||
request_id_item)))
|
beam_search_params,
|
||||||
|
request_id_item,
|
||||||
|
lora_request=lora_req)))
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
output = await asyncio.gather(*tasks)
|
output = await asyncio.gather(*tasks)
|
||||||
@@ -159,6 +164,7 @@ class EngineClient(ABC):
|
|||||||
tokens=current_beam.tokens + [token_id],
|
tokens=current_beam.tokens + [token_id],
|
||||||
logprobs=current_beam.logprobs +
|
logprobs=current_beam.logprobs +
|
||||||
[logprobs],
|
[logprobs],
|
||||||
|
lora_request=current_beam.lora_request,
|
||||||
cum_logprob=current_beam.cum_logprob +
|
cum_logprob=current_beam.cum_logprob +
|
||||||
logprob_obj.logprob,
|
logprob_obj.logprob,
|
||||||
multi_modal_data=current_beam.
|
multi_modal_data=current_beam.
|
||||||
|
|||||||
@@ -522,10 +522,28 @@ class LLM:
|
|||||||
executor = self.llm_engine.model_executor
|
executor = self.llm_engine.model_executor
|
||||||
return executor.apply_model(func)
|
return executor.apply_model(func)
|
||||||
|
|
||||||
|
def _get_beam_search_lora_requests(
    self,
    lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
    prompts: list[Union[TokensPrompt, TextPrompt]],
) -> list[Optional[LoRARequest]]:
    """Get the optional lora request corresponding to each prompt.

    Args:
        lora_request: None, a single request that is applied to every
            prompt, or a sequence holding one request per prompt.
        prompts: The prompts the requests are matched against; only their
            count is used here.

    Returns:
        A list with one (possibly None) entry per prompt.

    Raises:
        ValueError: If a sequence is passed whose length differs from the
            number of prompts.
        TypeError: If lora_request is none of the supported types.
    """
    # A missing or single request is broadcast to every prompt.
    if lora_request is None or isinstance(lora_request, LoRARequest):
        return [lora_request] * len(prompts)

    # Strings are sequences too, but never a valid list of requests; let
    # them fall through to the TypeError below.
    if isinstance(lora_request,
                  Sequence) and not isinstance(lora_request, (str, bytes)):
        if len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")
        # The return used to sit after the raise and was unreachable, so a
        # correctly sized list ended up in the TypeError below.
        return list(lora_request)

    raise TypeError(f"Invalid lora_request type {type(lora_request)}")
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
self,
|
self,
|
||||||
prompts: list[Union[TokensPrompt, TextPrompt]],
|
prompts: list[Union[TokensPrompt, TextPrompt]],
|
||||||
params: BeamSearchParams,
|
params: BeamSearchParams,
|
||||||
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
) -> list[BeamSearchOutput]:
|
) -> list[BeamSearchOutput]:
|
||||||
"""
|
"""
|
||||||
Generate sequences using beam search.
|
Generate sequences using beam search.
|
||||||
@@ -534,6 +552,7 @@ class LLM:
|
|||||||
prompts: A list of prompts. Each prompt can be a string or a list
|
prompts: A list of prompts. Each prompt can be a string or a list
|
||||||
of token IDs.
|
of token IDs.
|
||||||
params: The beam search parameters.
|
params: The beam search parameters.
|
||||||
|
lora_request: LoRA request to use for generation, if any.
|
||||||
"""
|
"""
|
||||||
# TODO: how does beam search work together with length penalty,
|
# TODO: how does beam search work together with length penalty,
|
||||||
# frequency, penalty, and stopping criteria, etc.?
|
# frequency, penalty, and stopping criteria, etc.?
|
||||||
@@ -543,6 +562,9 @@ class LLM:
|
|||||||
ignore_eos = params.ignore_eos
|
ignore_eos = params.ignore_eos
|
||||||
length_penalty = params.length_penalty
|
length_penalty = params.length_penalty
|
||||||
|
|
||||||
|
lora_requests = self._get_beam_search_lora_requests(
|
||||||
|
lora_request, prompts)
|
||||||
|
|
||||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||||
return get_beam_search_score(x.tokens, x.cum_logprob,
|
return get_beam_search_score(x.tokens, x.cum_logprob,
|
||||||
tokenizer.eos_token_id,
|
tokenizer.eos_token_id,
|
||||||
@@ -570,7 +592,7 @@ class LLM:
|
|||||||
temperature=temperature)
|
temperature=temperature)
|
||||||
instances: list[BeamSearchInstance] = []
|
instances: list[BeamSearchInstance] = []
|
||||||
|
|
||||||
for prompt in prompts:
|
for lora_req, prompt in zip(lora_requests, prompts):
|
||||||
# Add multimodal processor kwargs & data
|
# Add multimodal processor kwargs & data
|
||||||
mm_kwargs = {}
|
mm_kwargs = {}
|
||||||
if "multi_modal_data" in prompt:
|
if "multi_modal_data" in prompt:
|
||||||
@@ -586,7 +608,12 @@ class LLM:
|
|||||||
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
||||||
|
|
||||||
instances.append(
|
instances.append(
|
||||||
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
|
BeamSearchInstance(
|
||||||
|
prompt_tokens,
|
||||||
|
lora_request=lora_req,
|
||||||
|
logprobs=None,
|
||||||
|
**mm_kwargs,
|
||||||
|
), )
|
||||||
|
|
||||||
for _ in range(max_tokens):
|
for _ in range(max_tokens):
|
||||||
all_beams: list[BeamSearchSequence] = list(
|
all_beams: list[BeamSearchSequence] = list(
|
||||||
@@ -600,15 +627,17 @@ class LLM:
|
|||||||
if len(all_beams) == 0:
|
if len(all_beams) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
prompts_batch = [
|
# create the corresponding batch entries for prompt & optional lora
|
||||||
create_tokens_prompt_from_beam(beam) for beam in all_beams
|
prompts_batch, lora_req_batch = zip(
|
||||||
]
|
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
|
||||||
|
for beam in all_beams])
|
||||||
|
|
||||||
# only runs for one step
|
# only runs for one step
|
||||||
# we don't need to use tqdm here
|
# we don't need to use tqdm here
|
||||||
output = self.generate(prompts_batch,
|
output = self.generate(prompts_batch,
|
||||||
sampling_params=beam_search_params,
|
sampling_params=beam_search_params,
|
||||||
use_tqdm=False)
|
use_tqdm=False,
|
||||||
|
lora_request=lora_req_batch)
|
||||||
|
|
||||||
for (start, end), instance in zip(instance_start_and_end,
|
for (start, end), instance in zip(instance_start_and_end,
|
||||||
instances):
|
instances):
|
||||||
@@ -626,6 +655,7 @@ class LLM:
|
|||||||
new_beam = BeamSearchSequence(
|
new_beam = BeamSearchSequence(
|
||||||
tokens=current_beam.tokens + [token_id],
|
tokens=current_beam.tokens + [token_id],
|
||||||
logprobs=current_beam.logprobs + [logprobs],
|
logprobs=current_beam.logprobs + [logprobs],
|
||||||
|
lora_request=current_beam.lora_request,
|
||||||
cum_logprob=current_beam.cum_logprob +
|
cum_logprob=current_beam.cum_logprob +
|
||||||
logprob_obj.logprob,
|
logprob_obj.logprob,
|
||||||
multi_modal_data=current_beam.multi_modal_data,
|
multi_modal_data=current_beam.multi_modal_data,
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
prompt=engine_prompt,
|
prompt=engine_prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
|
|||||||
@@ -186,6 +186,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
prompt=engine_prompt,
|
prompt=engine_prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
|
|||||||
Reference in New Issue
Block a user