2025-03-10 00:23:11 -07:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-03-10 00:23:11 -07:00
"""
This module defines a framework for sampling benchmark requests from various
datasets . Each dataset subclass of BenchmarkDataset must implement sample
generation . Supported dataset types include :
- ShareGPT
- Random ( synthetic )
- Sonnet
- BurstGPT
- HuggingFace
- VisionArena
"""
import base64
import io
import json
2025-03-19 21:32:58 -07:00
import logging
2025-03-10 00:23:11 -07:00
import random
from abc import ABC , abstractmethod
from collections . abc import Mapping
2025-08-19 04:32:18 -04:00
from copy import deepcopy
2025-03-10 00:23:11 -07:00
from dataclasses import dataclass
from functools import cache
2025-03-31 00:38:58 -07:00
from io import BytesIO
from typing import Any , Callable , Optional , Union
2025-03-10 00:23:11 -07:00
import numpy as np
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import PreTrainedTokenizerBase
from vllm . lora . request import LoRARequest
from vllm . lora . utils import get_adapter_absolute_path
from vllm . multimodal import MultiModalDataDict
2025-05-22 18:59:18 -07:00
from vllm . multimodal . image import convert_image_mode
2025-03-10 00:23:11 -07:00
from vllm . transformers_utils . tokenizer import AnyTokenizer , get_lora_tokenizer
2025-03-19 21:32:58 -07:00
logger = logging . getLogger ( __name__ )
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@dataclass
class SampleRequest:
    """
    Represents a single inference request for benchmarking.
    """

    # Prompt text, or an arbitrary payload (e.g. a chat-format message list
    # produced by apply_multimodal_chat_transformation).
    prompt: Union[str, Any]
    # Number of tokens in the prompt.
    prompt_len: int
    # Number of output tokens the benchmark expects to be generated.
    expected_output_len: int
    # Optional multimodal payload (image/video content) for this request.
    multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
    # Optional LoRA adapter to apply for this request.
    lora_request: Optional[LoRARequest] = None
    # Optional identifier, typically request_id_prefix + sample index.
    request_id: Optional[str] = None
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
# Benchmark Dataset Base Class
# -----------------------------------------------------------------------------
class BenchmarkDataset(ABC):
    """
    Abstract base class for benchmark request datasets.

    Subclasses load data from a concrete source (JSON/CSV file, HuggingFace
    hub, synthetic generation, ...) and implement `sample` to turn it into a
    list of `SampleRequest` objects.
    """

    DEFAULT_SEED = 0
    # Subclasses that emit image/video content set this to True.
    IS_MULTIMODAL = False

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        random_seed: int = DEFAULT_SEED,
    ) -> None:
        """
        Initialize the BenchmarkDataset with an optional dataset path and
        random seed.

        Args:
            dataset_path (Optional[str]): Path to the dataset. If None, it
                indicates that a default or random dataset might be used.
            random_seed (int): Seed value for reproducible shuffling or
                sampling. Defaults to DEFAULT_SEED.
        """
        self.dataset_path = dataset_path
        # Set the random seed, ensuring that a None value is replaced with the
        # default seed.
        self.random_seed = (
            random_seed if random_seed is not None else self.DEFAULT_SEED
        )
        self.data = None

    def apply_multimodal_chat_transformation(
        self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
    ) -> list[dict]:
        """
        Transform a prompt and optional multimodal content into a chat format.

        This method is used for chat models that expect a specific
        conversation format.
        """
        content = [{"text": prompt, "type": "text"}]
        if mm_content is not None:
            content.append(mm_content)
        return [{"role": "user", "content": content}]

    def load_data(self) -> None:
        """
        Load data from the dataset path into self.data.

        This method must be overridden by subclasses since the method to load
        data will vary depending on the dataset format and source.

        Raises:
            NotImplementedError: If a subclass does not implement this method.
        """
        # TODO (jenniferzhao): add support for downloading data
        raise NotImplementedError("load_data must be implemented in subclasses.")

    def get_random_lora_request(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
    ) -> tuple[Optional[LoRARequest], AnyTokenizer]:
        """
        Optionally select a random LoRA request and return its associated
        tokenizer.

        This method is used when LoRA parameters are provided. It randomly
        selects a LoRA based on max_loras and retrieves a cached tokenizer for
        that LoRA if available. Otherwise, it returns the base tokenizer.

        Args:
            tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if
                no LoRA is selected.
            max_loras (Optional[int]): The maximum number of LoRAs available.
                If None, LoRA is not used.
            lora_path (Optional[str]): Path to the LoRA parameters on disk.
                If None, LoRA is not used.

        Returns:
            tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the
            first element is a LoRARequest (or None if not applicable) and the
            second element is the tokenizer associated with the LoRA request
            (or the base tokenizer).
        """
        if max_loras is None or lora_path is None:
            return None, tokenizer
        # Generate a random LoRA ID in the range [1, max_loras].
        lora_id = random.randint(1, max_loras)
        lora_request = LoRARequest(
            lora_name=str(lora_id),
            lora_int_id=lora_id,
            lora_path=lora_path_on_disk(lora_path),
        )
        if lora_id not in lora_tokenizer_cache:
            lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
        # Return lora_request and the cached tokenizer if available; otherwise,
        # return the base tokenizer
        return lora_request, lora_tokenizer_cache[lora_id] or tokenizer

    @abstractmethod
    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
    ) -> list[SampleRequest]:
        """
        Abstract method to generate sample requests from the dataset.

        Subclasses must override this method to implement dataset-specific
        logic for generating a list of SampleRequest objects.

        Args:
            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.
            request_id_prefix (str): The prefix of request_id.

        Returns:
            list[SampleRequest]: A list of sample requests generated from the
            dataset.
        """
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self,
        requests: list[SampleRequest],
        num_requests: int,
        request_id_prefix: str = "",
    ) -> None:
        """
        Oversamples the list of requests if its size is less than the desired
        number.

        Args:
            requests (List[SampleRequest]): The current list of sampled
                requests.
            num_requests (int): The target number of requests.
            request_id_prefix (str): The prefix of the request ids.
        """
        if len(requests) < num_requests:
            random.seed(self.random_seed)
            # Deep-copy the duplicates so each can carry a distinct request_id.
            additional = deepcopy(
                random.choices(requests, k=num_requests - len(requests))
            )
            # enumerate instead of index arithmetic over range(len(...)).
            for i, req in enumerate(additional):
                req.request_id = request_id_prefix + str(len(requests) + i)
            requests.extend(additional)
            logger.info(
                "Oversampled requests to reach %d total samples.", num_requests
            )
2025-03-19 21:32:58 -07:00
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
# Utility Functions and Global Caches
# -----------------------------------------------------------------------------
def is_valid_sequence(
    prompt_len: int,
    output_len: int,
    min_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
    skip_min_output_len_check: bool = False,
) -> bool:
    """
    Validate a sequence based on prompt and output lengths.

    Default pruning criteria are copied from the original
    `sample_hf_requests` and `sample_sharegpt_requests` functions in
    benchmark_serving.py, as well as from `sample_requests` in
    benchmark_throughput.py.
    """
    # Guard-clause style: bail out on the first violated criterion.
    if prompt_len < min_len:
        return False
    if output_len < min_len and not skip_min_output_len_check:
        return False
    if prompt_len > max_prompt_len:
        return False
    return prompt_len + output_len <= max_total_len
2025-03-10 00:23:11 -07:00
@cache
def lora_path_on_disk(lora_path: str) -> str:
    # Resolve (and memoize, via @cache) the absolute on-disk path for a LoRA
    # adapter path or identifier.
    return get_adapter_absolute_path(lora_path)


# Global cache for LoRA tokenizers, keyed by LoRA integer id.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
def process_image(image: Any) -> Mapping[str, Any]:
    """
    Process a single image input and return a multimedia content dictionary.

    Supports three input types:

    1. Dictionary with raw image bytes: expects a dict with a 'bytes' key
       containing raw image data, which is loaded as a PIL.Image.Image.
    2. PIL.Image.Image input: converts the image to RGB, saves it as an
       in-memory JPEG, and returns it as a base64 data URL.
    3. String input: treated as a URL or local file path; "file://" is
       prepended when the string doesn't start with "http://" or "file://".

    Raises:
        ValueError: If the input is not a supported type.
    """
    # The three accepted types are disjoint, so handling the string case
    # first does not change behavior.
    if isinstance(image, str):
        image_url = (
            image if image.startswith(("http://", "file://")) else f"file://{image}"
        )
        return {"type": "image_url", "image_url": {"url": image_url}}

    # Raw bytes wrapped in a dict: decode into a PIL image, then fall through
    # to the PIL branch below.
    if isinstance(image, dict) and "bytes" in image:
        image = Image.open(BytesIO(image["bytes"]))
    if isinstance(image, Image.Image):
        image = convert_image_mode(image, "RGB")
        with io.BytesIO() as buffer:
            image.save(buffer, format="JPEG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
        }

    raise ValueError(
        f"Invalid image input {image}. Must be a PIL.Image.Image"
        " or str or dictionary with raw image bytes."
    )
2025-03-10 00:23:11 -07:00
2025-08-19 16:42:31 -07:00
def process_video(video: Any) -> Mapping[str, Any]:
    """
    Process a single video input and return a multimedia content dictionary.

    Supports the following input types:

    1. Dictionary with raw video bytes: expects a dict with a 'bytes' key
       containing raw video data, returned as a base64 data URL.
    2. String input: treated as a URL or local file path; "file://" is
       prepended when the string doesn't start with "http://" or "file://".

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(video, dict) and "bytes" in video:
        encoded = base64.b64encode(video["bytes"]).decode("utf-8")
        return {
            "type": "video_url",
            "video_url": {"url": f"data:video/mp4;base64,{encoded}"},
        }
    if isinstance(video, str):
        video_url = (
            video if video.startswith(("http://", "file://")) else f"file://{video}"
        )
        return {"type": "video_url", "video_url": {"url": video_url}}
    raise ValueError(
        f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
    )
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
    """
    Synthetic dataset: builds prompts from random token ids so benchmarks can
    control input/output lengths without a real corpus.
    """

    # Default values copied from benchmark_serving.py for the random dataset.
    DEFAULT_PREFIX_LEN = 0
    DEFAULT_RANGE_RATIO = 0.0
    DEFAULT_INPUT_LEN = 1024
    DEFAULT_OUTPUT_LEN = 128

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        """
        Generate `num_requests` synthetic requests with lengths drawn
        uniformly from [X * (1 - range_ratio), X * (1 + range_ratio)].

        NOTE(review): uses the global numpy RNG (np.random.*) rather than a
        seeded Generator, so results depend on ambient RNG state — confirm
        callers seed numpy when reproducibility matters.
        """
        # Enforce range_ratio < 1
        assert range_ratio < 1.0, (
            "random_range_ratio must be < 1.0 to ensure a valid sampling range"
        )

        vocab_size = tokenizer.vocab_size
        # Reserve room for the special tokens the tokenizer adds, so the
        # final prompt stays close to `input_len` tokens.
        num_special_tokens = tokenizer.num_special_tokens_to_add()
        real_input_len = input_len - num_special_tokens

        prefix_token_ids = (
            np.random.randint(0, vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

        # New sampling logic: [X * (1 - b), X * (1 + b)]
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        # Ensure the lower bound for output length is at least 1 to prevent
        # sampling 0 tokens, which can cause request failures.
        output_low = max(output_low, 1)
        output_high = int(output_len * (1 + range_ratio))

        # Add logging for debugging
        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
        logger.info("Sampling output_len from [%s, %s]", output_low, output_high)

        input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
        output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
        offsets = np.random.randint(0, vocab_size, size=num_requests)

        requests = []
        for i in range(num_requests):
            # Per-request rotating token sequence; the `+ i` de-correlates
            # requests that drew the same offset.
            inner_seq = (
                (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
            ).tolist()
            token_sequence = prefix_token_ids + inner_seq
            prompt = tokenizer.decode(token_sequence)
            # After decoding the prompt we have to encode and decode it again.
            # This is done because in some cases N consecutive tokens
            # give a string tokenized into != N number of tokens.
            # For example for GPT2Tokenizer:
            # [6880, 6881] -> ['Ġcalls', 'here'] ->
            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
            # To avoid uncontrolled change of the prompt length,
            # the encoded sequence is truncated before being decode again.
            total_input_len = prefix_len + int(input_lens[i])
            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
                :total_input_len
            ]
            prompt = tokenizer.decode(re_encoded_sequence)
            # Report the actual (post round-trip) prompt length.
            total_input_len = len(re_encoded_sequence)
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )

        return requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
# -----------------------------------------------------------------------------
class ShareGPTDataset(BenchmarkDataset):
    """
    Implements the ShareGPT dataset. Loads data from a JSON file and generates
    sample requests based on conversation turns.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = json.load(f)
        # Keep only entries with at least a prompt/response pair of turns.
        self.data = [
            entry
            for entry in self.data
            if "conversations" in entry and len(entry["conversations"]) >= 2
        ]
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Sample up to `num_requests` requests from conversation turns."""
        samples: list = []
        ind = 0
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            turns = entry["conversations"]
            prompt = turns[0]["value"]
            completion = turns[1]["value"]

            # Pick the LoRA (and its tokenizer) first so token counts match
            # the tokenizer actually used for this request.
            lora_request, tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )

            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)
            expected_len = completion_len if output_len is None else output_len
            # Prune sequences that fall outside the standard length criteria.
            if not is_valid_sequence(
                prompt_len,
                expected_len,
                skip_min_output_len_check=output_len is not None,
            ):
                continue

            if image_path := entry.get("image"):
                mm_content = process_image(image_path)
            elif video_path := entry.get("video"):
                mm_content = process_video(video_path)
            else:
                mm_content = None
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)

            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=expected_len,
                    lora_request=lora_request,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
            ind += 1
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples
2025-05-31 15:07:38 -04:00
# -----------------------------------------------------------------------------
# Custom Dataset Implementation
# -----------------------------------------------------------------------------
class CustomDataset(BenchmarkDataset):
    """
    Implements the Custom dataset. Loads data from a JSONL file and generates
    sample requests based on conversation turns. E.g.,
    ```
    {"prompt": "What is the capital of India?"}
    {"prompt": "What is the capital of Iran?"}
    {"prompt": "What is the capital of China?"}
    ```
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")
        # self.data is normalized to a list of dicts, e.g.
        # [{"prompt": "What is the capital of India?"}, ...]. sample() relies
        # on this standardized layout, whatever the on-disk file type is.
        self.data = []
        if not self.dataset_path.endswith(".jsonl"):
            raise NotImplementedError(
                "Only JSONL format is supported for CustomDataset."
            )
        frame = pd.read_json(path_or_buf=self.dataset_path, lines=True)
        # Every row must expose a 'prompt' column.
        if "prompt" not in frame.columns:
            raise ValueError("JSONL file must contain a 'prompt' column.")
        # One dict per DataFrame row.
        for _, row in frame.iterrows():
            self.data.append(row.to_dict())
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Sample up to `num_requests` prompts from the loaded JSONL data."""
        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]
            if not skip_chat_template:
                # Render the prompt through the model's chat template.
                prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=len(tokenizer(prompt).input_ids),
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset. Loads poem lines from a
    text file and generates sample requests. Default values here copied from
    `benchmark_serving.py` for the sonnet dataset.
    """

    DEFAULT_PREFIX_LEN = 200
    DEFAULT_INPUT_LEN = 550
    DEFAULT_OUTPUT_LEN = 150

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        if not self.dataset_path:
            raise ValueError("dataset_path must be provided.")
        with open(self.dataset_path, encoding="utf-8") as f:
            self.data = f.readlines()

    def sample(
        self,
        tokenizer,
        num_requests: int,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build prompts out of randomly chosen poem lines."""
        # Average token count per poem line, used to size the prompts.
        tokenized_lines = [tokenizer(line).input_ids for line in self.data]
        avg_len = sum(len(t) for t in tokenized_lines) / len(tokenized_lines)

        # Build the base prompt and measure its chat-template overhead.
        base_prompt = "Pick as many lines as you can from these poem lines:\n"
        base_msg = [{"role": "user", "content": base_prompt}]
        base_fmt = tokenizer.apply_chat_template(
            base_msg, add_generation_prompt=True, tokenize=False
        )
        base_offset = len(tokenizer(base_fmt).input_ids)
        if input_len <= base_offset:
            raise ValueError(
                f"'input_len' must be higher than the base prompt length "
                f"({base_offset})."
            )

        # Determine how many poem lines to use.
        num_input_lines = round((input_len - base_offset) / avg_len)
        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
        prefix_lines = self.data[:num_prefix_lines]

        samples = []
        ind = 0
        # Keep drawing until enough prompts fit the requested input length.
        while len(samples) < num_requests:
            extra_lines = random.choices(
                self.data, k=num_input_lines - num_prefix_lines
            )
            prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
            msg = [{"role": "user", "content": prompt}]
            prompt_formatted = tokenizer.apply_chat_template(
                msg, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)

            if prompt_len <= input_len:
                samples.append(
                    SampleRequest(
                        prompt=prompt_formatted if return_prompt_formatted else prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                        request_id=request_id_prefix + str(ind),
                    )
                )
                # Only accepted samples consume an id, keeping ids contiguous.
                ind += 1
        return samples
# -----------------------------------------------------------------------------
# BurstGPT Dataset Implementation
# -----------------------------------------------------------------------------
class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset. Loads data from a CSV file and generates
    sample requests based on synthetic prompt generation. Only rows with Model
    "GPT-4" and positive response tokens are used.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(
        self,
    ):
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")
        df = pd.read_csv(self.dataset_path)
        # Keep only successful GPT-4 requests (positive response tokens).
        gpt4_df = df[df["Model"] == "GPT-4"]
        gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
        self.data = gpt4_df

    def _sample_loaded_data(self, num_requests: int) -> list:
        # Sample with replacement only when more rows are requested than
        # exist; random_state keeps the draw reproducible.
        with_replacement = num_requests > len(self.data)
        data = self.data.sample(
            n=num_requests,
            random_state=self.random_seed,
            replace=with_replacement,
        )
        # Convert the dataframe to a list of lists.
        return data.values.tolist()

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        samples = []
        rows = self._sample_loaded_data(num_requests=num_requests)
        for i in range(num_requests):
            input_len = int(rows[i][2])
            output_len = int(rows[i][3])
            lora_req, tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
            )
            vocab_size = tokenizer.vocab_size
            # Generate a synthetic prompt: a list of token IDs computed as
            # (i + j) modulo vocab_size.
            token_ids = [(i + j) % vocab_size for j in range(input_len)]
            samples.append(
                SampleRequest(
                    prompt=tokenizer.decode(token_ids),
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                    request_id=request_id_prefix + str(i),
                )
            )
        return samples
# -----------------------------------------------------------------------------
2025-03-31 00:38:58 -07:00
# HuggingFace Dataset Base Implementation
2025-03-10 00:23:11 -07:00
# -----------------------------------------------------------------------------
class HuggingFaceDataset(BenchmarkDataset):
    """Base class for datasets hosted on HuggingFace."""

    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

    def __init__(
        self,
        dataset_path: str,
        dataset_split: str,
        no_stream: bool = False,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(dataset_path=dataset_path, **kwargs)
        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        # Streaming is the default; `no_stream` opts into eager loading.
        self.load_stream = not no_stream
        self.load_data()

    def load_data(self) -> None:
        """Load data from HuggingFace datasets."""
        # Load, then shuffle with the dataset's seed for reproducibility.
        self.data = load_dataset(
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=self.load_stream,
        ).shuffle(seed=self.random_seed)
# -----------------------------------------------------------------------------
# Conversation Dataset Implementation
# -----------------------------------------------------------------------------


class ConversationDataset(HuggingFaceDataset):
    """Dataset for conversation data with multimodal support."""

    SUPPORTED_DATASET_PATHS = {
        "lmms-lab/LLaVA-OneVision-Data",
        "Aeala/ShareGPT_Vicuna_unfiltered",
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build up to ``num_requests`` requests from prompt/response pairs.

        When ``output_len`` is None the expected output length of each request
        tracks the tokenized length of the dataset's own completion.
        """
        # Only keep examples that actually contain a prompt/response pair.
        usable = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        dynamic_output = output_len is None
        requests: list = []
        ind = 0
        for item in usable:
            if len(requests) >= num_requests:
                break
            turns = item["conversations"]
            prompt = turns[0]["value"]
            completion = turns[1]["value"]
            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)
            if dynamic_output:
                output_len = completion_len
            assert isinstance(output_len, int) and output_len > 0
            # In dynamic mode, skip pairs outside the acceptable length range.
            if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
                continue
            mm_content = process_image(item["image"]) if "image" in item else None
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len and output len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
            ind += 1
        self.maybe_oversample_requests(requests, num_requests, request_id_prefix)
        return requests
# -----------------------------------------------------------------------------
# Vision Arena Dataset Implementation
# -----------------------------------------------------------------------------


class VisionArenaDataset(HuggingFaceDataset):
    """
    Vision Arena Dataset.
    """

    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {
        "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
        "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
    }
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Sample up to ``num_requests`` image+text requests.

        Raises:
            ValueError: if ``self.dataset_path`` has no registered parser in
                ``SUPPORTED_DATASET_PATHS``.
        """
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        # The prompt parser depends only on the dataset path, so resolve it
        # once up front instead of repeating the dict lookup (and the error
        # check) on every iteration. This also fails fast for unsupported
        # paths before any data is pulled.
        parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
        if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
        sampled_requests = []
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = parser_fn(item)
            mm_content = process_image(item["images"][0])
            prompt_len = len(tokenizer(prompt).input_ids)
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer
                # accurate and we will be using request output to count the
                # actual prompt len
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests, request_id_prefix)
        return sampled_requests
# -----------------------------------------------------------------------------
# Instruct Coder Dataset Implementation
# -----------------------------------------------------------------------------


class InstructCoderDataset(HuggingFaceDataset):
    """
    InstructCoder Dataset.
    https://huggingface.co/datasets/likaixin/InstructCoder

    InstructCoder is the dataset designed for general code editing. It consists
    of 114,239 instruction-input-output triplets, and covers multiple distinct
    code editing scenario.
    """

    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
    SUPPORTED_DATASET_PATHS = {
        "likaixin/InstructCoder",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build chat-templated code-editing requests from the dataset."""
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN
        sampled: list = []
        for idx, item in enumerate(self.data):
            if len(sampled) >= num_requests:
                break
            prompt = (
                f"{item['input']}\n\n{item['instruction']} Just output "
                "the code, do not include any explanation."
            )
            # Render through the model's chat template so the request looks
            # like a real user turn.
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            sampled.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=len(tokenizer(prompt).input_ids),
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(idx),
                )
            )
        self.maybe_oversample_requests(sampled, num_requests, request_id_prefix)
        return sampled
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------


class MTBenchDataset(HuggingFaceDataset):
    """
    MT-Bench Dataset.
    https://huggingface.co/datasets/philschmid/mt-bench

    We create a single turn dataset for MT-Bench.
    This is similar to Spec decoding benchmark setup in vLLM
    https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
    """  # noqa: E501

    DEFAULT_OUTPUT_LEN = 256  # avg len used in SD bench in vLLM
    SUPPORTED_DATASET_PATHS = {
        "philschmid/mt-bench",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build single-turn requests from the first turn of each entry."""
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN
        collected: list = []
        for idx, item in enumerate(self.data):
            if len(collected) >= num_requests:
                break
            first_turn = item["turns"][0]
            # Render the turn through the model's chat template.
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": first_turn}],
                add_generation_prompt=True,
                tokenize=False,
            )
            collected.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=len(tokenizer(prompt).input_ids),
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(idx),
                )
            )
        self.maybe_oversample_requests(collected, num_requests, request_id_prefix)
        return collected
# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------


class AIMODataset(HuggingFaceDataset):
    """
    Dataset class for processing a AIMO dataset with reasoning questions.
    """

    SUPPORTED_DATASET_PATHS = {
        "AI-MO/aimo-validation-aime",
        "AI-MO/NuminaMath-1.5",
        "AI-MO/NuminaMath-CoT",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Sample math problems as text-only requests.

        When ``output_len`` is None the expected output length of each request
        tracks the tokenized length of the dataset's reference solution.
        """
        dynamic_output = output_len is None
        picked: list = []
        ind = 0
        for item in self.data:
            if len(picked) >= num_requests:
                break
            prompt = item["problem"]
            completion = item["solution"]
            prompt_len = len(tokenizer(prompt).input_ids)
            completion_len = len(tokenizer(completion).input_ids)
            if dynamic_output:
                output_len = completion_len
            assert isinstance(output_len, int) and output_len > 0
            # With dynamic lengths, drop pairs outside the allowed budget;
            # reasoning outputs get a generous 32k total-length cap.
            if dynamic_output and not is_valid_sequence(
                prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
            ):
                continue
            picked.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=None,
                    request_id=request_id_prefix + str(ind),
                )
            )
            ind += 1
        self.maybe_oversample_requests(picked, num_requests, request_id_prefix)
        return picked
2025-04-19 11:24:14 +02:00
2025-05-06 21:38:45 +02:00
# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------
zeta_prompt = """ ### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides , suggesting the appropriate edits within the excerpt , taking into account the cursor location .
### User Edits:
{ }
### User Excerpt:
{ }
### Response:
2025-05-13 14:43:29 +01:00
""" # noqa: E501
2025-05-06 21:38:45 +02:00
def _format_zeta_prompt (
2025-05-13 14:43:29 +01:00
sample : dict , original_start_marker : str = " <|editable_region_start|> "
) - > dict :
2025-05-06 21:38:45 +02:00
""" Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
2025-05-13 14:43:29 +01:00
This function formats examples from the NEP dataset
into prompts and expected outputs . It could be
2025-05-06 21:38:45 +02:00
further extended to support more NEP datasets .
2025-05-13 14:43:29 +01:00
2025-05-06 21:38:45 +02:00
Args :
2025-05-13 14:43:29 +01:00
sample : The dataset sample containing events ,
2025-05-06 21:38:45 +02:00
inputs , and outputs .
2025-05-13 14:43:29 +01:00
original_start_marker : The marker indicating the
start of the editable region . Defaults to
2025-05-06 21:38:45 +02:00
" <|editable_region_start|> " .
2025-05-13 14:43:29 +01:00
2025-05-06 21:38:45 +02:00
Returns :
A dictionary with the formatted prompts and expected outputs .
"""
events = sample [ " events " ]
input = sample [ " input " ]
output = sample [ " output " ]
prompt = zeta_prompt . format ( events , input )
# following the original implementation, extract the focused region
# from the raw output
output_start_index = output . find ( original_start_marker )
output_focused_region = output [ output_start_index : ]
expected_output = output_focused_region
return { " prompt " : prompt , " expected_output " : expected_output }
class NextEditPredictionDataset(HuggingFaceDataset):
    """
    Dataset class for processing a Next Edit Prediction dataset.
    """

    SUPPORTED_DATASET_PATHS = {
        "zed-industries/zeta",
    }
    # Maps a dataset path to the function that turns a raw row into a
    # prompt / expected-output pair.
    MAPPING_PROMPT_FUNCS = {
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        **kwargs,
    ):
        """Build requests whose expected output length comes from the gold edit."""
        formatter = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
        if formatter is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
        samples: list = []
        for i, raw in enumerate(self.data):
            formatted = formatter(raw)
            prompt = formatted["prompt"]
            gold = formatted["expected_output"]
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=len(tokenizer(prompt).input_ids),
                    expected_output_len=len(tokenizer(gold).input_ids),
                    request_id=request_id_prefix + str(i),
                )
            )
            if len(samples) >= num_requests:
                break
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------


class ASRDataset(HuggingFaceDataset):
    """
    Dataset class for processing a ASR dataset for transcription.
    Tested on the following set:

    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | Dataset        | Domain                                 | Speaking Style           | hf-subset                   |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    | TED-LIUM       | TED talks                              | Oratory                  | release1, release2, release3|
    |                |                                        |                          | release3-speaker-adaptation |
    | VoxPopuli      | European Parliament                    | Oratory                  | en, de, it, fr, ...         |
    | LibriSpeech    | Audiobook                              | Narrated                 | "LIUM/tedlium"              |
    | GigaSpeech     | Audiobook, podcast, YouTube            | Narrated, spontaneous    | xs, s, m, l, xl, dev, test  |
    | SPGISpeech     | Financial meetings                     | Oratory, spontaneous     | S, M, L, dev, test          |
    | AMI            | Meetings                               | Spontaneous              | ihm, sdm                    |
    +----------------+----------------------------------------+--------------------------+-----------------------------+
    """  # noqa: E501

    SUPPORTED_DATASET_PATHS = {
        "openslr/librispeech_asr",
        "facebook/voxpopuli",
        "LIUM/tedlium",
        "edinburghcstr/ami",
        "speechcolab/gigaspeech",
        "kensho/spgispeech",
    }
    DEFAULT_OUTPUT_LEN = 128
    IS_MULTIMODAL = True

    # TODO Whisper-specific. Abstract interface when more models are supported.
    TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    skip_long_audios: bool = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        """Build audio-transcription requests, skipping clips Whisper can't take."""
        import librosa

        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN
        # Every request shares the same fixed transcription preamble.
        prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
        prompt_len = len(tokenizer(prompt).input_ids)
        requests: list = []
        skipped = 0
        ind = 0
        for item in self.data:
            if len(requests) >= num_requests:
                break
            audio = item["audio"]
            waveform, sample_rate = audio["array"], audio["sampling_rate"]
            duration_s = librosa.get_duration(y=waveform, sr=sample_rate)
            # Whisper max supported duration
            if self.skip_long_audios and duration_s > 30:
                skipped += 1
                continue
            requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data={"audio": (waveform, sample_rate)},
                    request_id=request_id_prefix + str(ind),
                )
            )
            ind += 1
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to "
                "their length being greater than "
                "what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(requests, num_requests, request_id_prefix)
        return requests