EAGLE 3: Fix preamble so that measured speedup over Eagle 1 becomes 32% instead of 5% on MTBench (#25916)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
Ekagra Ranjan
2025-10-02 14:29:35 -04:00
committed by GitHub
parent 1e50f1be70
commit 1cab2f9cad

View File

@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Do not oversample if the dataset has " \
             "fewer samples than num-prompts.",
     )
+    parser.add_argument(
+        "--skip-chat-template",
+        action="store_true",
+        help=
+        "Skip applying chat template to prompt for datasets that support it.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help=
         "Number of output tokens per request, used only for custom dataset.",
     )
-    custom_group.add_argument(
-        "--custom-skip-chat-template",
-        action="store_true",
-        help=
-        "Skip applying chat template to prompt, used only for custom dataset.",
-    )
 
     spec_bench_group = parser.add_argument_group("spec bench dataset options")
     spec_bench_group.add_argument(
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             output_len=args.custom_output_len,
-            skip_chat_template=args.custom_skip_chat_template,
+            skip_chat_template=args.skip_chat_template,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         )
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
             **hf_kwargs
         )
@@ -1815,7 +1816,6 @@ class SpecBench(CustomDataset):
     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
-        kwargs["skip_chat_template"] = False
         return super().sample(**kwargs)
@@ -2221,6 +2221,7 @@ class InstructCoderDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         **kwargs) -> list:
@@ -2236,14 +2237,15 @@ class InstructCoderDataset(HuggingFaceDataset):
             )
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2284,6 +2286,7 @@ class MTBenchDataset(HuggingFaceDataset):
         num_requests: int,
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         **kwargs,
@@ -2298,14 +2301,15 @@ class MTBenchDataset(HuggingFaceDataset):
             prompt = item["turns"][0]
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2349,6 +2353,7 @@ class BlazeditDataset(HuggingFaceDataset):
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         output_len: Optional[int] = None,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         min_distance: float = 0.0,
@@ -2372,7 +2377,7 @@ class BlazeditDataset(HuggingFaceDataset):
             # template copied from
             # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
-            instruction = f"""Given a code file, please apply the change requests and generate the new file.
+            prompt = f"""Given a code file, please apply the change requests and generate the new file.
 Original file:
 ```python
@@ -2385,14 +2390,15 @@ Change request:
Please generate the new code file in the "New file" section below.""" # noqa: E501 Please generate the new code file in the "New file" section below.""" # noqa: E501
# apply template # apply template
prompt = tokenizer.apply_chat_template( if not skip_chat_template:
[{ prompt = tokenizer.apply_chat_template(
"role": "user", [{
"content": instruction "role": "user",
}], "content": prompt
add_generation_prompt=True, }],
tokenize=False, add_generation_prompt=True,
) tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)