[Frontend][Multimodal] Allow skipping media data when UUIDs are provided. (#23950)

Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Chenheli Hua
2025-09-12 19:16:06 -07:00
committed by GitHub
parent 4fdd6f5cbf
commit 7f2ea7074e
9 changed files with 970 additions and 96 deletions

View File

@@ -1764,6 +1764,7 @@ def apply_image_repeat(
probs = [1.0 - image_repeat_prob, image_repeat_prob]
inputs = []
inputs_with_empty_media = []
cur_image = data
for i in range(num_prompts):
if image_repeat_prob is not None:
@@ -1774,14 +1775,25 @@ def apply_image_repeat(
new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val)
uuid = "uuid_{}".format(i)
inputs.append(
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: cur_image},
"multi_modal_uuids": {modality: uuid},
}
)
return inputs
inputs_with_empty_media.append(
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: None},
"multi_modal_uuids": {modality: uuid},
}
)
return inputs, inputs_with_empty_media
@contextmanager
@@ -1860,6 +1872,13 @@ def parse_args():
help="If True, then use different prompt (with the same multi-modal "
"data) for each request.",
)
parser.add_argument(
"--verify-mm-cache-hit-with-uuids",
action="store_true",
help="If True, will send all requests in a second batch with empty mm "
"data to verify cache hits with UUIDs.",
)
return parser.parse_args()
@@ -1903,26 +1922,48 @@ def main(args):
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
uuid = "uuid_0"
inputs = {
"prompt": prompts[0],
"multi_modal_data": {modality: data},
"multi_modal_uuids": {modality: uuid},
}
inputs_with_empty_media = {
"prompt": prompts[0],
"multi_modal_data": {modality: None},
"multi_modal_uuids": {modality: uuid},
}
else:
# Batch inference
if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(
args.image_repeat_prob, args.num_prompts, data, prompts, modality
inputs, inputs_with_empty_media = apply_image_repeat(
args.image_repeat_prob,
args.num_prompts,
data,
prompts,
modality,
)
else:
# Use the same image for all prompts
inputs = [
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: data},
}
for i in range(args.num_prompts)
]
inputs = []
inputs_with_empty_media = []
for i in range(args.num_prompts):
uuid = "uuid_{}".format(i)
inputs.append(
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: data},
"multi_modal_uuids": {modality: uuid},
}
)
inputs_with_empty_media.append(
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: None},
"multi_modal_uuids": {modality: uuid},
}
)
# Add LoRA request if applicable
lora_request = (
@@ -1942,6 +1983,26 @@ def main(args):
print(generated_text)
print("-" * 50)
if args.verify_mm_cache_hit_with_uuids:
try:
# Verify cache hits with UUIDs
print(
"Sending a second batch of requests with empty media"
" and matching UUIDs."
)
outputs = llm.generate(
inputs_with_empty_media,
sampling_params=sampling_params,
lora_request=lora_request,
)
print("-" * 50)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
print("-" * 50)
except Exception as e:
print(f"Failed to verify cache hits with UUIDs. Error: {e}")
if __name__ == "__main__":
args = parse_args()