[Misc] support arbitrary MM datasets in spec dec bench (#33486)

Signed-off-by: kkt-cohere <komal@cohere.com>
Signed-off-by: Komal Kumar Teru <162363718+kkt-cohere@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Komal Kumar Teru
Date: 2026-02-02 14:19:48 +05:30
Committed by: GitHub
Parent: ab374786c7
Commit: ba871fb788
3 changed files with 156 additions and 18 deletions
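
With this change, the speculative-decoding benchmark can draw its prompts from any multimodal dataset wired into vllm.benchmarks.datasets, rather than only the hard-coded --custom-mm-prompts set. As a hypothetical invocation (script path, dataset name, and media directory are illustrative, not taken from this commit; --backend and --allowed-local-media-path are the flags added below):

    python examples/offline_inference/spec_decode.py --method eagle3 --backend openai-chat --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --num-prompts 80 --allowed-local-media-path /data/images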


@@ -5,7 +5,6 @@ from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 from vllm.benchmarks.datasets import add_dataset_parser, get_samples
-from vllm.inputs import TokensPrompt
 from vllm.v1.metrics.reader import Counter, Vector

 try:
@@ -56,6 +55,7 @@
         default="eagle",
         choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
     )
+    parser.add_argument("--backend", type=str, default="openai")
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
     parser.add_argument("--prompt-lookup-min", type=int, default=2)
@@ -75,12 +75,11 @@
     parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
     parser.add_argument("--disable-padded-drafter-batch", action="store_true")
     parser.add_argument("--max-num-seqs", type=int, default=None)
+    parser.add_argument("--allowed-local-media-path", type=str, default="")
     return parser.parse_args()


 def main(args):
-    args.endpoint_type = "openai-chat"
     model_dir = args.model_dir
     if args.model_dir is None:
         if args.custom_mm_prompts:
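
allowed_local_media_path is a pre-existing vLLM engine argument: it names a directory from which file:// media URLs may be loaded, which matters once the benchmark accepts datasets that reference local images. A sketch of the kind of message it unlocks (path hypothetical):

    message = {
        "role": "user",
        "content": [
            # Loading a local file like this is only permitted when the engine
            # was constructed with allowed_local_media_path="/data/images".
            {"type": "image_url", "image_url": {"url": "file:///data/images/cat.png"}},
            {"type": "text", "text": "What is in this image?"},
        ],
    }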
@@ -91,19 +90,25 @@
             )
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     args.custom_skip_chat_template = True
-    if not args.custom_mm_prompts:
-        prompts = get_samples(args, tokenizer)
-        # add_special_tokens is False to avoid adding bos twice
-        # when using chat templates
-        prompt_ids = [
-            tokenizer.encode(prompt.prompt, add_special_tokens=False)
-            for prompt in prompts
-        ]
+    if args.custom_mm_prompts:
+        prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts)
     else:
-        prompts = get_custom_mm_prompts(args.num_prompts)
+        prompts = get_samples(args, tokenizer)
+        if args.enable_multimodal_chat:
+            llm_prompts = [p.prompt for p in prompts]
+        else:
+            # add_special_tokens is False to avoid adding bos twice
+            # when using chat templates
+            llm_prompts = [
+                {
+                    "prompt_token_ids": tokenizer.encode(
+                        prompt.prompt, add_special_tokens=False
+                    ),
+                    "multi_modal_data": prompt.multi_modal_data,
+                }
+                for prompt in prompts
+            ]
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir
         if args.method == "eagle" and eagle_dir is None:
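
After this hunk, llm_prompts takes one of three forms: chat-style prompts from get_custom_mm_prompts() when --custom-mm-prompts is set, the samples' own prompt field when the backend is openai-chat, or token-id dicts otherwise. A minimal sketch of the third form, with illustrative values (multi_modal_data comes from the sampled request and may be None for text-only datasets):

    llm_prompts = [
        {
            "prompt_token_ids": [128000, 9906, 1917],  # tokenizer.encode(..., add_special_tokens=False)
            "multi_modal_data": None,                  # or e.g. {"image": [pil_image]}
        },
    ]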
@@ -154,16 +159,17 @@
         limit_mm_per_prompt={"image": 5},
         disable_chunked_mm_input=True,
         max_num_seqs=args.max_num_seqs,
+        allowed_local_media_path=args.allowed_local_media_path,
     )

     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
-    if not args.custom_mm_prompts:
+    if args.backend == "openai-chat":
+        outputs = llm.chat(llm_prompts, sampling_params=sampling_params)
+    else:
         outputs = llm.generate(
-            [TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
+            llm_prompts,
             sampling_params=sampling_params,
         )
-    else:
-        outputs = llm.chat(prompts, sampling_params=sampling_params)

     # print the generated text
     if args.print_output:
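
The dispatch now keys on the backend instead of --custom-mm-prompts, so text-only and multimodal datasets share one code path. A self-contained sketch of the same pattern outside the benchmark (model and prompts illustrative; not part of this commit):

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    backend = "openai-chat"  # as --backend would set it
    if backend == "openai-chat":
        # llm.chat applies the model's chat template to OpenAI-style messages.
        outputs = llm.chat(
            [[{"role": "user", "content": "Explain speculative decoding in one line."}]],
            sampling_params=sampling_params,
        )
    else:
        # llm.generate consumes pre-tokenized prompts and skips the template.
        outputs = llm.generate(
            [{"prompt_token_ids": [128000, 9906, 1917]}],
            sampling_params=sampling_params,
        )
    print(outputs[0].outputs[0].text)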
@@ -219,6 +225,8 @@
 if __name__ == "__main__":
     args = parse_args()
+    args.enable_multimodal_chat = args.backend == "openai-chat"
     acceptance_length = main(args)
     if args.test: