Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -37,7 +37,8 @@ class Qwen2VLTester:
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
|
||||
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
"What is in the image?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
def __init__(self, config: TestConfig):
|
||||
self.config = config
|
||||
@@ -56,68 +57,68 @@ class Qwen2VLTester:
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
|
||||
def run_test(self,
|
||||
images: list[ImageAsset],
|
||||
expected_outputs: list[str],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 5):
|
||||
|
||||
def run_test(
|
||||
self,
|
||||
images: list[ImageAsset],
|
||||
expected_outputs: list[str],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 5,
|
||||
):
|
||||
sampling_params = vllm.SamplingParams(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
inputs = [{
|
||||
"prompt": self.PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {
|
||||
"image": asset.pil_image
|
||||
},
|
||||
} for asset in images]
|
||||
|
||||
lora_request = LoRARequest(str(lora_id), lora_id,
|
||||
self.config.lora_path)
|
||||
outputs = self.llm.generate(inputs,
|
||||
sampling_params,
|
||||
lora_request=lora_request)
|
||||
generated_texts = [
|
||||
output.outputs[0].text.strip() for output in outputs
|
||||
inputs = [
|
||||
{
|
||||
"prompt": self.PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {"image": asset.pil_image},
|
||||
}
|
||||
for asset in images
|
||||
]
|
||||
|
||||
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
|
||||
outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
|
||||
generated_texts = [output.outputs[0].text.strip() for output in outputs]
|
||||
|
||||
# Validate outputs
|
||||
for generated, expected in zip(generated_texts, expected_outputs):
|
||||
assert expected.startswith(
|
||||
generated), f"Generated text {generated} doesn't "
|
||||
assert expected.startswith(generated), (
|
||||
f"Generated text {generated} doesn't "
|
||||
)
|
||||
f"match expected pattern {expected}"
|
||||
|
||||
def run_beam_search_test(self,
|
||||
images: list[ImageAsset],
|
||||
expected_outputs: list[list[str]],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
beam_width: int = 2,
|
||||
max_tokens: int = 5):
|
||||
def run_beam_search_test(
|
||||
self,
|
||||
images: list[ImageAsset],
|
||||
expected_outputs: list[list[str]],
|
||||
lora_id: Optional[int] = None,
|
||||
temperature: float = 0,
|
||||
beam_width: int = 2,
|
||||
max_tokens: int = 5,
|
||||
):
|
||||
beam_search_params = BeamSearchParams(
|
||||
beam_width=beam_width, max_tokens=max_tokens, temperature=temperature
|
||||
)
|
||||
|
||||
beam_search_params = BeamSearchParams(beam_width=beam_width,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature)
|
||||
inputs = [
|
||||
{
|
||||
"prompt": self.PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {"image": asset.pil_image},
|
||||
}
|
||||
for asset in images
|
||||
]
|
||||
|
||||
inputs = [{
|
||||
"prompt": self.PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {
|
||||
"image": asset.pil_image
|
||||
},
|
||||
} for asset in images]
|
||||
|
||||
lora_request = LoRARequest(str(lora_id), lora_id,
|
||||
self.config.lora_path)
|
||||
outputs = self.llm.beam_search(inputs,
|
||||
beam_search_params,
|
||||
lora_request=lora_request)
|
||||
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
|
||||
outputs = self.llm.beam_search(
|
||||
inputs, beam_search_params, lora_request=lora_request
|
||||
)
|
||||
|
||||
for output_obj, expected_outs in zip(outputs, expected_outputs):
|
||||
output_texts = [seq.text for seq in output_obj.sequences]
|
||||
assert output_texts == expected_outs, \
|
||||
f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501
|
||||
assert output_texts == expected_outs, (
|
||||
f"Generated texts {output_texts} do not match expected {expected_outs}"
|
||||
) # noqa: E501
|
||||
|
||||
|
||||
TEST_IMAGES = [
|
||||
@@ -144,27 +145,25 @@ QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm")
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm",
|
||||
)
|
||||
def test_qwen2vl_lora(qwen2vl_lora_files):
|
||||
"""Test Qwen 2.0 VL model with LoRA"""
|
||||
config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
|
||||
lora_path=qwen2vl_lora_files)
|
||||
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
|
||||
tester = Qwen2VLTester(config)
|
||||
|
||||
# Test with different LoRA IDs
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(TEST_IMAGES,
|
||||
expected_outputs=EXPECTED_OUTPUTS,
|
||||
lora_id=lora_id)
|
||||
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm")
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm",
|
||||
)
|
||||
def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
|
||||
"""Test Qwen 2.0 VL model with LoRA through beam search."""
|
||||
config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
|
||||
lora_path=qwen2vl_lora_files)
|
||||
config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files)
|
||||
tester = Qwen2VLTester(config)
|
||||
|
||||
# Test with different LoRA IDs
|
||||
@@ -176,7 +175,8 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
|
||||
tester.run_beam_search_test(
|
||||
[ImageAsset("cherry_blossom")],
|
||||
expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS,
|
||||
lora_id=lora_id)
|
||||
lora_id=lora_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
@@ -185,12 +185,9 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
|
||||
)
|
||||
def test_qwen25vl_lora(qwen25vl_lora_files):
|
||||
"""Test Qwen 2.5 VL model with LoRA"""
|
||||
config = TestConfig(model_path=QWEN25VL_MODEL_PATH,
|
||||
lora_path=qwen25vl_lora_files)
|
||||
config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files)
|
||||
tester = Qwen2VLTester(config)
|
||||
|
||||
# Test with different LoRA IDs
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(TEST_IMAGES,
|
||||
expected_outputs=EXPECTED_OUTPUTS,
|
||||
lora_id=lora_id)
|
||||
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
|
||||
|
||||
Reference in New Issue
Block a user