[BugFix] Fix chunked prompt logprobs + preemption (#29071)

This commit is contained in:
Nick Hill
2025-11-22 13:07:18 -08:00
committed by GitHub
parent eb5352a770
commit 7df331c66b
6 changed files with 127 additions and 31 deletions

View File

@@ -219,9 +219,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +382,6 @@ class InputBatch:
if sampling_params.logprobs == -1
else sampling_params.logprobs
)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
@@ -488,7 +479,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
@@ -972,10 +962,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

View File

@@ -393,6 +393,9 @@ class GPUModelRunner(
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
@@ -687,6 +690,7 @@ class GPUModelRunner(
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ class GPUModelRunner(
)
self.requests[req_id] = req_state
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
scheduler_output, self.vllm_config
)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.input_batch.num_prompt_logprobs, (
assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
"logprobs for prompt tokens, tokens, please disable "
"it when the requests need prompt logprobs"
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, LogprobsTensors | None]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
num_prompt_logprobs_dict = self.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
@@ -3447,7 +3458,10 @@ class GPUModelRunner(
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens.get(req_id)
if num_tokens is None:
# This can happen if the request was preempted in prefill stage.
continue
# Get metadata for this request.
request = self.requests[req_id]

View File

@@ -149,9 +149,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
@@ -317,7 +312,6 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
@@ -584,10 +578,6 @@ class InputBatch:
def max_num_logprobs(self) -> int | None:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

View File

@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# Initialize input batch early to avoid AttributeError in _update_states
self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
lora_request=new_req_data.lora_request,
)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.