[Bugfix] Fix lora tests (#34834)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Cyrus Leung
2026-02-19 05:22:31 +08:00
committed by GitHub
parent 847a57cd12
commit 61cf087680
3 changed files with 28 additions and 44 deletions

View File

@@ -153,5 +153,5 @@ def test_default_mm_lora_does_not_expand_string_reqs(vllm_runner):
# Then check to make sure the submitted lora request # Then check to make sure the submitted lora request
# and text prompt were zipped together correctly # and text prompt were zipped together correctly
engine_args, engine_kwargs = mock_add_request.call_args engine_args, engine_kwargs = mock_add_request.call_args
assert engine_args[1]["prompt"] == AUDIO_PROMPT
assert engine_kwargs["lora_request"] is None assert engine_kwargs["lora_request"] is None
assert engine_kwargs["prompt_text"] == AUDIO_PROMPT

View File

@@ -88,9 +88,8 @@ class Qwen2VLTester:
# Validate outputs # Validate outputs
for generated, expected in zip(generated_texts, expected_outputs): for generated, expected in zip(generated_texts, expected_outputs):
assert expected.startswith(generated), ( assert expected.startswith(generated), (
f"Generated text {generated} doesn't " f"Generated text {generated} doesn't match expected pattern {expected}"
) )
f"match expected pattern {expected}"
def run_beam_search_test( def run_beam_search_test(
self, self,
@@ -118,11 +117,14 @@ class Qwen2VLTester:
inputs, beam_search_params, lora_request=lora_request inputs, beam_search_params, lora_request=lora_request
) )
for output_obj, expected_outs in zip(outputs, expected_outputs): for output_obj, expected_texts in zip(outputs, expected_outputs):
output_texts = [seq.text for seq in output_obj.sequences] output_texts = [seq.text for seq in output_obj.sequences]
assert output_texts == expected_outs, (
f"Generated texts {output_texts} do not match expected {expected_outs}" for output_text, expected_text in zip(output_texts, expected_texts):
) # noqa: E501 # NOTE beam search .text contains the whole text including inputs
assert output_text.endswith(expected_text), (
f"Generated {output_text} does not match expected {expected_text}"
)
TEST_IMAGES = [ TEST_IMAGES = [
@@ -151,11 +153,10 @@ EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.", "A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
] ]
# NOTE - beam search .text contains the whole text
EXPECTED_BEAM_SEARCH_OUTPUTS = [ EXPECTED_BEAM_SEARCH_OUTPUTS = [
[ [
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501 "A majestic skyscraper stands",
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501 "A majestic tower stands tall",
], ],
] ]

View File

@@ -542,51 +542,31 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def _resolve_lora_reqs( def _resolve_mm_lora(
self,
prompts: Sequence[ProcessorInputs],
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
):
lora_config = self.llm_engine.vllm_config.lora_config
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
if (
lora_config is None
or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None)
):
return seq_lora_requests
return [
self._resolve_single_prompt_mm_lora(
prompt,
lora_req,
lora_config.default_mm_loras,
)
for prompt, lora_req in zip(prompts, seq_lora_requests)
]
def _resolve_single_prompt_mm_lora(
self, self,
prompt: ProcessorInputs, prompt: ProcessorInputs,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None, ) -> LoRARequest | None:
): if prompt["type"] != "multimodal":
if not default_mm_loras or prompt["type"] != "multimodal": return lora_request
lora_config = self.llm_engine.vllm_config.lora_config
default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
if not default_mm_loras:
return lora_request return lora_request
prompt_modalities = prompt["mm_placeholders"].keys() prompt_modalities = prompt["mm_placeholders"].keys()
intersection = set(prompt_modalities).intersection(default_mm_loras.keys()) intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
if not intersection: if not intersection:
return lora_request return lora_request
if len(intersection) > 1: if len(intersection) > 1:
# TODO: Would be nice to be able to have multiple loras per prompt # TODO: Would be nice to be able to have multiple loras per prompt
logger.warning( logger.warning(
"Multiple modality specific loras were registered and would be " "Multiple modality specific loras were registered and would be "
"used by a single prompt consuming several modalities; " "used by a single prompt consuming several modalities; "
"currently we only support one lora per request; as such, " "currently we only support one lora per request; as such, "
" lora(s) registered with modalities: %s" "lora(s) registered with modalities: %s will be skipped",
" will be skipped",
intersection, intersection,
) )
return lora_request return lora_request
@@ -1915,7 +1895,10 @@ class LLM:
request_id = self._add_request( request_id = self._add_request(
prompt, prompt,
params[i], params[i],
lora_request=None if lora_requests is None else lora_requests[i], lora_request=self._resolve_mm_lora(
prompt,
None if lora_requests is None else lora_requests[i],
),
priority=0 if priorities is None else priorities[i], priority=0 if priorities is None else priorities[i],
) )
added_request_ids.append(request_id) added_request_ids.append(request_id)