[Bugfix] Fix Random Dataset Prefix Length Inaccuracy (#33907)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Frank Wang
2026-02-12 18:21:19 -08:00
committed by GitHub
parent de13dd781f
commit b86bf4417e

View File

@@ -380,7 +380,7 @@ def gen_prompt_decode_to_target_len(
max_retry: int = 10,
add_special_tokens: bool = False,
rng: np.random.Generator | None = None,
) -> tuple[str, list[int]]:
) -> tuple[str, list[int], int]:
"""
Ensure decoded-then-encoded prompt length matches the target token length.
@@ -392,7 +392,9 @@ def gen_prompt_decode_to_target_len(
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
Returns a tuple of the final prompt string and the adjusted token sequence.
Returns a tuple of the final prompt string, the adjusted token sequence,
and the token mismatch (final_len - target_token_len) if the retry budget
is exhausted.
"""
remain_num_try = max_retry
token_mismatch = 0
@@ -499,7 +501,7 @@ class RandomDataset(BenchmarkDataset):
allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
# Generate prefix once
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
requests = []
token_mismatch_total = 0
@@ -554,19 +556,36 @@ class RandomDataset(BenchmarkDataset):
def get_prefix(
self,
tokenizer: TokenizerLike,
allowed_tokens: np.ndarray,
prefix_len: int,
) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
allowed_tokens[
self._rng.integers(0, len(allowed_tokens), size=prefix_len)
].tolist()
if prefix_len > 0
else []
if prefix_len <= 0:
return []
prefix_tokens = allowed_tokens[
self._rng.integers(0, len(allowed_tokens), size=prefix_len)
].tolist()
_, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(
tokenizer=tokenizer,
token_sequence=prefix_tokens,
target_token_len=prefix_len,
add_special_tokens=False,
rng=self._rng,
)
if token_mismatch != 0:
sign = "more" if token_mismatch > 0 else "fewer"
logger.warning(
"Prefix tokenization produced %d %s tokens than expected "
"after decoding and re-encoding. This is expected due to "
"the imperfect nature of the sampling procedure",
abs(token_mismatch),
sign,
)
return adjusted_tokens
def get_sampling_params(
self,
@@ -1128,7 +1147,7 @@ class RandomMultiModalDataset(RandomDataset):
"Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
)
# Generate prefix once
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
# Add synthetic multimodal items to each request
mm_requests = []
token_mismatch_total = 0