[Bugfix] Fix lora tests (#34834)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Authored by Cyrus Leung on 2026-02-19 05:22:31 +08:00, committed by GitHub
parent 847a57cd12
commit 61cf087680
3 changed files with 28 additions and 44 deletions

View File

@@ -153,5 +153,5 @@ def test_default_mm_lora_does_not_expand_string_reqs(vllm_runner):
     # Then check to make sure the submitted lora request
     # and text prompt were zipped together correctly
     engine_args, engine_kwargs = mock_add_request.call_args
-    assert engine_args[1]["prompt"] == AUDIO_PROMPT
     assert engine_kwargs["lora_request"] is None
+    assert engine_kwargs["prompt_text"] == AUDIO_PROMPT
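Note: the test inspects the patched engine via the mock's `call_args`, which unpacks into the positional and keyword arguments of the recorded call; after this fix the prompt is expected under the `prompt_text` keyword rather than inside a positional prompt dict. A minimal sketch of that inspection pattern, with stand-in values rather than vLLM's real objects:

    # Sketch of the call_args inspection used by the test above.
    # "req-0" and the prompt string are illustrative stand-ins.
    from unittest.mock import MagicMock

    mock_add_request = MagicMock()
    mock_add_request("req-0", prompt_text="<audio prompt>", lora_request=None)

    # call_args unpacks into (positional args, keyword args)
    engine_args, engine_kwargs = mock_add_request.call_args
    assert engine_args[0] == "req-0"
    assert engine_kwargs["prompt_text"] == "<audio prompt>"
    assert engine_kwargs["lora_request"] is None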

View File

@@ -88,9 +88,8 @@ class Qwen2VLTester:
         # Validate outputs
         for generated, expected in zip(generated_texts, expected_outputs):
             assert expected.startswith(generated), (
-                f"Generated text {generated} doesn't "
-                f"match expected pattern {expected}"
+                f"Generated text {generated} doesn't match expected pattern {expected}"
             )

     def run_beam_search_test(
         self,
@@ -118,11 +117,14 @@ class Qwen2VLTester:
             inputs, beam_search_params, lora_request=lora_request
         )

-        for output_obj, expected_outs in zip(outputs, expected_outputs):
+        for output_obj, expected_texts in zip(outputs, expected_outputs):
             output_texts = [seq.text for seq in output_obj.sequences]
-            assert output_texts == expected_outs, (
-                f"Generated texts {output_texts} do not match expected {expected_outs}"
-            )  # noqa: E501
+
+            for output_text, expected_text in zip(output_texts, expected_texts):
+                # NOTE beam search .text contains the whole text including inputs
+                assert output_text.endswith(expected_text), (
+                    f"Generated {output_text} does not match expected {expected_text}"
+                )


 TEST_IMAGES = [
@@ -151,11 +153,10 @@ EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.", "A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
] ]
# NOTE - beam search .text contains the whole text
EXPECTED_BEAM_SEARCH_OUTPUTS = [ EXPECTED_BEAM_SEARCH_OUTPUTS = [
[ [
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic skyscraper stands", # noqa: E501 "A majestic skyscraper stands",
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>\n<|im_start|>assistant\nA majestic tower stands tall", # noqa: E501 "A majestic tower stands tall",
], ],
] ]
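Note: both changes in this file follow from the fact called out in the new inline comment: beam search's `.text` contains the whole text, prompt and chat template included, so exact equality against a short expected completion can never hold. The fix asserts on the suffix and trims the expected outputs to match. An illustrative sketch with made-up strings:

    # Illustrative only: beam search .text = prompt + completion.
    full_text = (
        "<|im_start|>user\nWhat is in the image?<|im_end|>\n"
        "<|im_start|>assistant\nA majestic tower stands tall"
    )
    expected = "A majestic tower stands tall"

    assert full_text != expected          # old equality check would fail
    assert full_text.endswith(expected)   # new suffix check passes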

View File

@@ -542,51 +542,31 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs, RequestOutput)

-    def _resolve_lora_reqs(
-        self,
-        prompts: Sequence[ProcessorInputs],
-        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
-    ):
-        lora_config = self.llm_engine.vllm_config.lora_config
-        seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
-        if (
-            lora_config is None
-            or not self.model_config.is_multimodal_model
-            or (lora_config and lora_config.default_mm_loras is None)
-        ):
-            return seq_lora_requests
-
-        return [
-            self._resolve_single_prompt_mm_lora(
-                prompt,
-                lora_req,
-                lora_config.default_mm_loras,
-            )
-            for prompt, lora_req in zip(prompts, seq_lora_requests)
-        ]
-
-    def _resolve_single_prompt_mm_lora(
+    def _resolve_mm_lora(
         self,
         prompt: ProcessorInputs,
         lora_request: LoRARequest | None,
-        default_mm_loras: dict[str, str] | None,
-    ):
-        if not default_mm_loras or prompt["type"] != "multimodal":
+    ) -> LoRARequest | None:
+        if prompt["type"] != "multimodal":
+            return lora_request
+
+        lora_config = self.llm_engine.vllm_config.lora_config
+        default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
+        if not default_mm_loras:
             return lora_request

         prompt_modalities = prompt["mm_placeholders"].keys()
         intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
         if not intersection:
             return lora_request

         if len(intersection) > 1:
             # TODO: Would be nice to be able to have multiple loras per prompt
             logger.warning(
-                "Multiple modality specific loras were registered and would be"
-                " used by a single prompt consuming several modalities; "
-                " currently we only support one lora per request; as such,"
-                " lora(s) registered with modalities: %s"
-                " will be skipped",
+                "Multiple modality specific loras were registered and would be "
+                "used by a single prompt consuming several modalities; "
+                "currently we only support one lora per request; as such, "
+                "lora(s) registered with modalities: %s will be skipped",
                 intersection,
             )
             return lora_request
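Note: the refactor replaces the batch-level `_resolve_lora_reqs` (which fanned one request out across all prompts) with a per-prompt `_resolve_mm_lora` invoked at the `_add_request` call site in the next hunk. Restated as a standalone sketch; the dict-shaped prompt and string-valued returns are simplified stand-ins for `ProcessorInputs` and `LoRARequest`, and the final single-modality branch is an assumption, since it falls outside the hunk:

    # Simplified sketch of the per-prompt resolution logic above.
    def resolve_mm_lora(prompt, lora_request, default_mm_loras):
        # Text-only prompts and configs without defaults pass through untouched.
        if prompt["type"] != "multimodal" or not default_mm_loras:
            return lora_request

        matching = set(prompt["mm_placeholders"]) & set(default_mm_loras)
        if not matching or len(matching) > 1:
            # No applicable default, or ambiguous (one LoRA per request).
            return lora_request

        # Assumed from context (outside the hunk): one matching modality
        # selects its default LoRA unless an explicit request already exists.
        return lora_request or default_mm_loras[matching.pop()]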
@@ -1915,7 +1895,10 @@ class LLM:
             request_id = self._add_request(
                 prompt,
                 params[i],
-                lora_request=None if lora_requests is None else lora_requests[i],
+                lora_request=self._resolve_mm_lora(
+                    prompt,
+                    None if lora_requests is None else lora_requests[i],
+                ),
                 priority=0 if priorities is None else priorities[i],
             )
             added_request_ids.append(request_id)
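Note: with the call-site change above, every request added through this loop now passes through `_resolve_mm_lora`, so multimodal prompts can pick up a per-modality default LoRA. A hypothetical usage sketch; the diff only shows that `lora_config.default_mm_loras` maps modality names to LoRAs, so the constructor kwarg here is an assumption about how that config gets populated:

    # Hypothetical configuration; exact kwarg plumbing is not part of this diff.
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2-VL-2B-Instruct",
        enable_lora=True,
        default_mm_loras={"image": "/path/to/vision-lora"},  # modality -> LoRA
    )

    # Text-only prompts keep lora_request=None; image prompts resolve to the
    # registered default via _resolve_mm_lora.
    outputs = llm.generate("Describe the image.")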