EAGLE 3: Fix preamble so that measured speedup over Eagle 1 becomes 32% instead of 5% on MTBench (#25916)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
help="Do not oversample if the dataset has " \
|
help="Do not oversample if the dataset has " \
|
||||||
"fewer samples than num-prompts.",
|
"fewer samples than num-prompts.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
"Skip applying chat template to prompt for datasets that support it.",
|
||||||
|
)
|
||||||
|
|
||||||
# group for dataset specific arguments
|
# group for dataset specific arguments
|
||||||
custom_group = parser.add_argument_group("custom dataset options")
|
custom_group = parser.add_argument_group("custom dataset options")
|
||||||
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
help=
|
help=
|
||||||
"Number of output tokens per request, used only for custom dataset.",
|
"Number of output tokens per request, used only for custom dataset.",
|
||||||
)
|
)
|
||||||
custom_group.add_argument(
|
|
||||||
"--custom-skip-chat-template",
|
|
||||||
action="store_true",
|
|
||||||
help=
|
|
||||||
"Skip applying chat template to prompt, used only for custom dataset.",
|
|
||||||
)
|
|
||||||
|
|
||||||
spec_bench_group = parser.add_argument_group("spec bench dataset options")
|
spec_bench_group = parser.add_argument_group("spec bench dataset options")
|
||||||
spec_bench_group.add_argument(
|
spec_bench_group.add_argument(
|
||||||
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
output_len=args.custom_output_len,
|
output_len=args.custom_output_len,
|
||||||
skip_chat_template=args.custom_skip_chat_template,
|
skip_chat_template=args.skip_chat_template,
|
||||||
request_id_prefix=args.request_id_prefix,
|
request_id_prefix=args.request_id_prefix,
|
||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
)
|
)
|
||||||
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
output_len=args.hf_output_len,
|
output_len=args.hf_output_len,
|
||||||
request_id_prefix=args.request_id_prefix,
|
request_id_prefix=args.request_id_prefix,
|
||||||
no_oversample=args.no_oversample,
|
no_oversample=args.no_oversample,
|
||||||
|
skip_chat_template=args.skip_chat_template,
|
||||||
**hf_kwargs
|
**hf_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1815,7 +1816,6 @@ class SpecBench(CustomDataset):
|
|||||||
|
|
||||||
def sample(self, **kwargs) -> list:
|
def sample(self, **kwargs) -> list:
|
||||||
# leverage CustomDataset sample
|
# leverage CustomDataset sample
|
||||||
kwargs["skip_chat_template"] = False
|
|
||||||
return super().sample(**kwargs)
|
return super().sample(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -2221,6 +2221,7 @@ class InstructCoderDataset(HuggingFaceDataset):
|
|||||||
num_requests: int,
|
num_requests: int,
|
||||||
output_len: Optional[int] = None,
|
output_len: Optional[int] = None,
|
||||||
enable_multimodal_chat: bool = False,
|
enable_multimodal_chat: bool = False,
|
||||||
|
skip_chat_template: bool = False,
|
||||||
request_id_prefix: str = "",
|
request_id_prefix: str = "",
|
||||||
no_oversample: bool = False,
|
no_oversample: bool = False,
|
||||||
**kwargs) -> list:
|
**kwargs) -> list:
|
||||||
@@ -2236,14 +2237,15 @@ class InstructCoderDataset(HuggingFaceDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# apply template
|
# apply template
|
||||||
prompt = tokenizer.apply_chat_template(
|
if not skip_chat_template:
|
||||||
[{
|
prompt = tokenizer.apply_chat_template(
|
||||||
"role": "user",
|
[{
|
||||||
"content": prompt
|
"role": "user",
|
||||||
}],
|
"content": prompt
|
||||||
add_generation_prompt=True,
|
}],
|
||||||
tokenize=False,
|
add_generation_prompt=True,
|
||||||
)
|
tokenize=False,
|
||||||
|
)
|
||||||
|
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
@@ -2284,6 +2286,7 @@ class MTBenchDataset(HuggingFaceDataset):
|
|||||||
num_requests: int,
|
num_requests: int,
|
||||||
output_len: Optional[int] = None,
|
output_len: Optional[int] = None,
|
||||||
enable_multimodal_chat: bool = False,
|
enable_multimodal_chat: bool = False,
|
||||||
|
skip_chat_template: bool = False,
|
||||||
request_id_prefix: str = "",
|
request_id_prefix: str = "",
|
||||||
no_oversample: bool = False,
|
no_oversample: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -2298,14 +2301,15 @@ class MTBenchDataset(HuggingFaceDataset):
|
|||||||
prompt = item["turns"][0]
|
prompt = item["turns"][0]
|
||||||
|
|
||||||
# apply template
|
# apply template
|
||||||
prompt = tokenizer.apply_chat_template(
|
if not skip_chat_template:
|
||||||
[{
|
prompt = tokenizer.apply_chat_template(
|
||||||
"role": "user",
|
[{
|
||||||
"content": prompt
|
"role": "user",
|
||||||
}],
|
"content": prompt
|
||||||
add_generation_prompt=True,
|
}],
|
||||||
tokenize=False,
|
add_generation_prompt=True,
|
||||||
)
|
tokenize=False,
|
||||||
|
)
|
||||||
|
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
@@ -2349,6 +2353,7 @@ class BlazeditDataset(HuggingFaceDataset):
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
num_requests: int,
|
num_requests: int,
|
||||||
output_len: Optional[int] = None,
|
output_len: Optional[int] = None,
|
||||||
|
skip_chat_template: bool = False,
|
||||||
request_id_prefix: str = "",
|
request_id_prefix: str = "",
|
||||||
no_oversample: bool = False,
|
no_oversample: bool = False,
|
||||||
min_distance: float = 0.0,
|
min_distance: float = 0.0,
|
||||||
@@ -2372,7 +2377,7 @@ class BlazeditDataset(HuggingFaceDataset):
|
|||||||
|
|
||||||
# template copied from
|
# template copied from
|
||||||
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
|
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
|
||||||
instruction = f"""Given a code file, please apply the change requests and generate the new file.
|
prompt = f"""Given a code file, please apply the change requests and generate the new file.
|
||||||
|
|
||||||
Original file:
|
Original file:
|
||||||
```python
|
```python
|
||||||
@@ -2385,14 +2390,15 @@ Change request:
|
|||||||
Please generate the new code file in the "New file" section below.""" # noqa: E501
|
Please generate the new code file in the "New file" section below.""" # noqa: E501
|
||||||
|
|
||||||
# apply template
|
# apply template
|
||||||
prompt = tokenizer.apply_chat_template(
|
if not skip_chat_template:
|
||||||
[{
|
prompt = tokenizer.apply_chat_template(
|
||||||
"role": "user",
|
[{
|
||||||
"content": instruction
|
"role": "user",
|
||||||
}],
|
"content": prompt
|
||||||
add_generation_prompt=True,
|
}],
|
||||||
tokenize=False,
|
add_generation_prompt=True,
|
||||||
)
|
tokenize=False,
|
||||||
|
)
|
||||||
|
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user