diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 86e080b55..36573a040 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -380,7 +380,7 @@ def gen_prompt_decode_to_target_len( max_retry: int = 10, add_special_tokens: bool = False, rng: np.random.Generator | None = None, -) -> tuple[str, list[int]]: +) -> tuple[str, list[int], int]: """ Ensure decoded-then-encoded prompt length matches the target token length. @@ -392,7 +392,9 @@ def gen_prompt_decode_to_target_len( [6880, 6881] -> ['Ġcalls', 'here'] -> [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - Returns a tuple of the final prompt string and the adjusted token sequence. + Returns a tuple of the final prompt string, the adjusted token sequence, + and the token mismatch (final_len - target_token_len) if the retry budget + is exhausted. """ remain_num_try = max_retry token_mismatch = 0 @@ -499,7 +501,7 @@ class RandomDataset(BenchmarkDataset): allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens))) # Generate prefix once - prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len) + prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len) requests = [] token_mismatch_total = 0 @@ -554,19 +556,36 @@ class RandomDataset(BenchmarkDataset): def get_prefix( self, + tokenizer: TokenizerLike, allowed_tokens: np.ndarray, prefix_len: int, ) -> list[int]: """ Get the prefix for the dataset. """ - return ( - allowed_tokens[ - self._rng.integers(0, len(allowed_tokens), size=prefix_len) - ].tolist() - if prefix_len > 0 - else [] + if prefix_len <= 0: + return [] + + prefix_tokens = allowed_tokens[ + self._rng.integers(0, len(allowed_tokens), size=prefix_len) + ].tolist() + _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len( + tokenizer=tokenizer, + token_sequence=prefix_tokens, + target_token_len=prefix_len, + add_special_tokens=False, + rng=self._rng, ) + if token_mismatch != 0: + sign = "more" if token_mismatch > 0 else "fewer" + logger.warning( + "Prefix tokenization produced %d %s tokens than expected " + "after decoding and re-encoding. This is expected due to " + "the imperfect nature of the sampling procedure", + abs(token_mismatch), + sign, + ) + return adjusted_tokens def get_sampling_params( self, @@ -1128,7 +1147,7 @@ class RandomMultiModalDataset(RandomDataset): "Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size ) # Generate prefix once - prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len) + prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len) # Add synthetic multimodal items to each request mm_requests = [] token_mismatch_total = 0