[Bugfix] Fix lora tests (#34834)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -153,5 +153,5 @@ def test_default_mm_lora_does_not_expand_string_reqs(vllm_runner):
|
|||||||
# Then check to make sure the submitted lora request
|
# Then check to make sure the submitted lora request
|
||||||
# and text prompt were zipped together correctly
|
# and text prompt were zipped together correctly
|
||||||
engine_args, engine_kwargs = mock_add_request.call_args
|
engine_args, engine_kwargs = mock_add_request.call_args
|
||||||
|
assert engine_args[1]["prompt"] == AUDIO_PROMPT
|
||||||
assert engine_kwargs["lora_request"] is None
|
assert engine_kwargs["lora_request"] is None
|
||||||
assert engine_kwargs["prompt_text"] == AUDIO_PROMPT
|
|
||||||
|
|||||||
@@ -88,9 +88,8 @@ class Qwen2VLTester:
|
|||||||
# Validate outputs
|
# Validate outputs
|
||||||
for generated, expected in zip(generated_texts, expected_outputs):
|
for generated, expected in zip(generated_texts, expected_outputs):
|
||||||
assert expected.startswith(generated), (
|
assert expected.startswith(generated), (
|
||||||
f"Generated text {generated} doesn't "
|
f"Generated text {generated} doesn't match expected pattern {expected}"
|
||||||
)
|
)
|
||||||
f"match expected pattern {expected}"
|
|
||||||
|
|
||||||
def run_beam_search_test(
|
def run_beam_search_test(
|
||||||
self,
|
self,
|
||||||
@@ -118,11 +117,14 @@ class Qwen2VLTester:
|
|||||||
inputs, beam_search_params, lora_request=lora_request
|
inputs, beam_search_params, lora_request=lora_request
|
||||||
)
|
)
|
||||||
|
|
||||||
for output_obj, expected_outs in zip(outputs, expected_outputs):
|
for output_obj, expected_texts in zip(outputs, expected_outputs):
|
||||||
output_texts = [seq.text for seq in output_obj.sequences]
|
output_texts = [seq.text for seq in output_obj.sequences]
|
||||||
assert output_texts == expected_outs, (
|
|
||||||
f"Generated texts {output_texts} do not match expected {expected_outs}"
|
for output_text, expected_text in zip(output_texts, expected_texts):
|
||||||
) # noqa: E501
|
# NOTE beam search .text contains the whole text including inputs
|
||||||
|
assert output_text.endswith(expected_text), (
|
||||||
|
f"Generated {output_text} does not match expected {expected_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TEST_IMAGES = [
|
TEST_IMAGES = [
|
||||||
@@ -151,11 +153,10 @@ EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
|
|||||||
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
|
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE - beam search .text contains the whole text
|
|
||||||
EXPECTED_BEAM_SEARCH_OUTPUTS = [
|
EXPECTED_BEAM_SEARCH_OUTPUTS = [
|
||||||
[
|
[
|
||||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501
|
"A majestic skyscraper stands",
|
||||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501
|
"A majestic tower stands tall",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -542,51 +542,31 @@ class LLM:
|
|||||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||||
|
|
||||||
def _resolve_lora_reqs(
|
def _resolve_mm_lora(
|
||||||
self,
|
|
||||||
prompts: Sequence[ProcessorInputs],
|
|
||||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
|
||||||
):
|
|
||||||
lora_config = self.llm_engine.vllm_config.lora_config
|
|
||||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
|
|
||||||
|
|
||||||
if (
|
|
||||||
lora_config is None
|
|
||||||
or not self.model_config.is_multimodal_model
|
|
||||||
or (lora_config and lora_config.default_mm_loras is None)
|
|
||||||
):
|
|
||||||
return seq_lora_requests
|
|
||||||
|
|
||||||
return [
|
|
||||||
self._resolve_single_prompt_mm_lora(
|
|
||||||
prompt,
|
|
||||||
lora_req,
|
|
||||||
lora_config.default_mm_loras,
|
|
||||||
)
|
|
||||||
for prompt, lora_req in zip(prompts, seq_lora_requests)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _resolve_single_prompt_mm_lora(
|
|
||||||
self,
|
self,
|
||||||
prompt: ProcessorInputs,
|
prompt: ProcessorInputs,
|
||||||
lora_request: LoRARequest | None,
|
lora_request: LoRARequest | None,
|
||||||
default_mm_loras: dict[str, str] | None,
|
) -> LoRARequest | None:
|
||||||
):
|
if prompt["type"] != "multimodal":
|
||||||
if not default_mm_loras or prompt["type"] != "multimodal":
|
return lora_request
|
||||||
|
|
||||||
|
lora_config = self.llm_engine.vllm_config.lora_config
|
||||||
|
default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
|
||||||
|
if not default_mm_loras:
|
||||||
return lora_request
|
return lora_request
|
||||||
|
|
||||||
prompt_modalities = prompt["mm_placeholders"].keys()
|
prompt_modalities = prompt["mm_placeholders"].keys()
|
||||||
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
|
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
|
||||||
if not intersection:
|
if not intersection:
|
||||||
return lora_request
|
return lora_request
|
||||||
|
|
||||||
if len(intersection) > 1:
|
if len(intersection) > 1:
|
||||||
# TODO: Would be nice to be able to have multiple loras per prompt
|
# TODO: Would be nice to be able to have multiple loras per prompt
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Multiple modality specific loras were registered and would be"
|
"Multiple modality specific loras were registered and would be "
|
||||||
" used by a single prompt consuming several modalities; "
|
"used by a single prompt consuming several modalities; "
|
||||||
" currently we only support one lora per request; as such,"
|
"currently we only support one lora per request; as such, "
|
||||||
" lora(s) registered with modalities: %s"
|
"lora(s) registered with modalities: %s will be skipped",
|
||||||
" will be skipped",
|
|
||||||
intersection,
|
intersection,
|
||||||
)
|
)
|
||||||
return lora_request
|
return lora_request
|
||||||
@@ -1915,7 +1895,10 @@ class LLM:
|
|||||||
request_id = self._add_request(
|
request_id = self._add_request(
|
||||||
prompt,
|
prompt,
|
||||||
params[i],
|
params[i],
|
||||||
lora_request=None if lora_requests is None else lora_requests[i],
|
lora_request=self._resolve_mm_lora(
|
||||||
|
prompt,
|
||||||
|
None if lora_requests is None else lora_requests[i],
|
||||||
|
),
|
||||||
priority=0 if priorities is None else priorities[i],
|
priority=0 if priorities is None else priorities[i],
|
||||||
)
|
)
|
||||||
added_request_ids.append(request_id)
|
added_request_ids.append(request_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user