[Frontend][Multimodal] Allow skipping media data when UUIDs are provided. (#23950)
Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.me> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@@ -1764,6 +1764,7 @@ def apply_image_repeat(
|
||||
probs = [1.0 - image_repeat_prob, image_repeat_prob]
|
||||
|
||||
inputs = []
|
||||
inputs_with_empty_media = []
|
||||
cur_image = data
|
||||
for i in range(num_prompts):
|
||||
if image_repeat_prob is not None:
|
||||
@@ -1774,14 +1775,25 @@ def apply_image_repeat(
|
||||
new_val = (i // 256 // 256, i // 256, i % 256)
|
||||
cur_image.putpixel((0, 0), new_val)
|
||||
|
||||
uuid = "uuid_{}".format(i)
|
||||
|
||||
inputs.append(
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: cur_image},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
)
|
||||
|
||||
return inputs
|
||||
inputs_with_empty_media.append(
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: None},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
)
|
||||
|
||||
return inputs, inputs_with_empty_media
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -1860,6 +1872,13 @@ def parse_args():
|
||||
help="If True, then use different prompt (with the same multi-modal "
|
||||
"data) for each request.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verify-mm-cache-hit-with-uuids",
|
||||
action="store_true",
|
||||
help="If True, will send all requests in a second batch with empty mm "
|
||||
"data to verify cache hits with UUIDs.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -1903,26 +1922,48 @@ def main(args):
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
# Single inference
|
||||
uuid = "uuid_0"
|
||||
inputs = {
|
||||
"prompt": prompts[0],
|
||||
"multi_modal_data": {modality: data},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
inputs_with_empty_media = {
|
||||
"prompt": prompts[0],
|
||||
"multi_modal_data": {modality: None},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
else:
|
||||
# Batch inference
|
||||
if args.image_repeat_prob is not None:
|
||||
# Repeat images with specified probability of "image_repeat_prob"
|
||||
inputs = apply_image_repeat(
|
||||
args.image_repeat_prob, args.num_prompts, data, prompts, modality
|
||||
inputs, inputs_with_empty_media = apply_image_repeat(
|
||||
args.image_repeat_prob,
|
||||
args.num_prompts,
|
||||
data,
|
||||
prompts,
|
||||
modality,
|
||||
)
|
||||
else:
|
||||
# Use the same image for all prompts
|
||||
inputs = [
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: data},
|
||||
}
|
||||
for i in range(args.num_prompts)
|
||||
]
|
||||
inputs = []
|
||||
inputs_with_empty_media = []
|
||||
for i in range(args.num_prompts):
|
||||
uuid = "uuid_{}".format(i)
|
||||
inputs.append(
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: data},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
)
|
||||
inputs_with_empty_media.append(
|
||||
{
|
||||
"prompt": prompts[i % len(prompts)],
|
||||
"multi_modal_data": {modality: None},
|
||||
"multi_modal_uuids": {modality: uuid},
|
||||
}
|
||||
)
|
||||
|
||||
# Add LoRA request if applicable
|
||||
lora_request = (
|
||||
@@ -1942,6 +1983,26 @@ def main(args):
|
||||
print(generated_text)
|
||||
print("-" * 50)
|
||||
|
||||
if args.verify_mm_cache_hit_with_uuids:
|
||||
try:
|
||||
# Verify cache hits with UUIDs
|
||||
print(
|
||||
"Sending a second batch of requests with empty media"
|
||||
" and matching UUIDs."
|
||||
)
|
||||
outputs = llm.generate(
|
||||
inputs_with_empty_media,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
print("-" * 50)
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
print("-" * 50)
|
||||
except Exception as e:
|
||||
print(f"Failed to verify cache hits with UUIDs. Error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
Reference in New Issue
Block a user