[Misc]add coding benchmark for speculative decoding (#15303)

Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
This commit is contained in:
Chen Xia
2025-03-27 19:47:05 -07:00
committed by GitHub
parent 4ae17bf1e2
commit e7f720ea56
3 changed files with 101 additions and 21 deletions

View File

@@ -715,3 +715,66 @@ class VisionArenaDataset(HuggingFaceDataset):
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------
class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is a dataset designed for general code editing.
    It consists of 114,239 instruction-input-output triplets
    and covers multiple distinct code editing scenarios.
    """

    # Output length used by sample() when the caller passes output_len=None.
    DEFAULT_OUTPUT_LEN = 200
    DEFAULT_NUM_REQUESTS = 1000
    # The only dataset path this class accepts (checked in __init__).
    INSTRUCT_CODER_DATASET_PATH = "likaixin/InstructCoder"

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Initialise the base dataset, then validate the path and split.

        Args:
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``self.dataset_path`` is not
                ``INSTRUCT_CODER_DATASET_PATH``, or if no subset is given
                and ``self.dataset_split`` is not ``"train"``.
        """
        super().__init__(**kwargs)
        if self.dataset_path != self.INSTRUCT_CODER_DATASET_PATH:
            # Adjacent f-string literals keep the message on one clean line;
            # a backslash continuation inside the literal would splice the
            # source indentation of the next line into the message.
            raise ValueError(
                f"Only support {self.INSTRUCT_CODER_DATASET_PATH} dataset. "
                f"This data path {self.dataset_path} is not valid.")
        if self.dataset_subset is None and self.dataset_split != "train":
            raise ValueError("Dataset split must be 'train'.")

    def load_data(self) -> None:
        """Open the dataset in streaming mode and store it shuffled.

        The shuffled stream is stored on ``self.data``; the shuffle is
        seeded with ``self.random_seed``.
        """
        dataset = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=True,
        )
        self.data = dataset.shuffle(seed=self.random_seed)

    def sample(self,
               tokenizer: PreTrainedTokenizerBase,
               num_requests: int,
               output_len: Optional[int] = None,
               enable_multimodal_chat: bool = False,
               **kwargs) -> list:
        """Build up to ``num_requests`` sample requests from ``self.data``.

        Each prompt is the item's ``instruction`` followed by ``":\\n"`` and
        the item's ``input``; its length is the number of token ids the
        tokenizer produces for that prompt.

        Args:
            tokenizer: Callable tokenizer whose result exposes ``input_ids``.
            num_requests: Number of requests to collect before stopping.
            output_len: Expected output length recorded on every request;
                ``DEFAULT_OUTPUT_LEN`` is used when this is ``None``.
            enable_multimodal_chat: Accepted for signature compatibility;
                not used by this method.
            **kwargs: Accepted and ignored.

        Returns:
            The list of ``SampleRequest`` objects, after being passed
            through ``self.maybe_oversample_requests``.
        """
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            prompt = f"{item['instruction']}:\n{item['input']}"
            prompt_len = len(tokenizer(prompt).input_ids)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                ))
        # NOTE(review): maybe_oversample_requests is defined outside this
        # class; presumably it tops the list up when the dataset yields
        # fewer than num_requests items - confirm against the base class.
        self.maybe_oversample_requests(sampled_requests, num_requests)
        return sampled_requests