[BugFix] Fix chunked prompt logprobs + preemption (#29071)
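Summary: per-request `num_prompt_logprobs` tracking moves out of `InputBatch` and into `GPUModelRunner` / `TPUModelRunner`, where it is populated when a new request is admitted and cleared when the request finishes, and the prompt-logprobs loop now looks up `num_scheduled_tokens` with `.get()` and skips requests that were preempted during the prefill stage instead of raising a `KeyError`. A small illustrative sketch of this guard follows the GPUModelRunner hunks below.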
@@ -219,9 +219,6 @@ class InputBatch:
         self.generators: dict[int, torch.Generator] = {}
 
         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}
 
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -385,12 +382,6 @@ class InputBatch:
                 if sampling_params.logprobs == -1
                 else sampling_params.logprobs
             )
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = (
-                self.vocab_size
-                if sampling_params.prompt_logprobs == -1
-                else sampling_params.prompt_logprobs
-            )
 
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
@@ -488,7 +479,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         self.has_allowed_token_ids.discard(req_id)
@@ -972,10 +962,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
@@ -393,6 +393,9 @@ class GPUModelRunner(
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
         self.comm_stream = torch.cuda.Stream()
 
         # Input Batch
@@ -687,6 +690,7 @@ class GPUModelRunner(
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -755,6 +759,13 @@ class GPUModelRunner(
             )
             self.requests[req_id] = req_state
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 self._init_mrope_positions(req_state)
@@ -2671,7 +2682,7 @@ class GPUModelRunner(
             scheduler_output, self.vllm_config
         )
         if self.cache_config.kv_sharing_fast_prefill:
-            assert not self.input_batch.num_prompt_logprobs, (
+            assert not self.num_prompt_logprobs, (
                 "--kv-sharing-fast-prefill produces incorrect "
                 "logprobs for prompt tokens, tokens, please disable "
                 "it when the requests need prompt logprobs"
@@ -3436,7 +3447,7 @@ class GPUModelRunner(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, LogprobsTensors | None]:
-        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
             return {}
 
@@ -3447,7 +3458,10 @@ class GPUModelRunner(
         # maintainable loop over optimal performance.
         completed_prefill_reqs = []
         for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-            num_tokens = num_scheduled_tokens[req_id]
+            num_tokens = num_scheduled_tokens.get(req_id)
+            if num_tokens is None:
+                # This can happen if the request was preempted in prefill stage.
+                continue
 
             # Get metadata for this request.
             request = self.requests[req_id]
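To make the preemption interaction concrete, here is a small self-contained sketch of the pattern the diff adopts (illustrative code only, not vLLM's; the Runner class and its method names are invented for this example): prompt-logprob counts are owned by the runner and keyed by request id, and the per-step gathering loop simply skips any request the scheduler did not schedule this step, which is what happens to a request preempted mid-prefill.

# Illustrative sketch only -- not vLLM code. Names (Runner, add_request,
# finish_request, gather_prompt_logprob_chunks) are invented for this example.


class Runner:
    def __init__(self) -> None:
        # Runner-owned bookkeeping survives preemption: a preempted request is
        # evicted from the persistent batch, but not from this dict.
        self.num_prompt_logprobs: dict[str, int] = {}

    def add_request(self, req_id: str, prompt_logprobs: int | None) -> None:
        # Mirrors the new-request path above: record the count at admission.
        if prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = prompt_logprobs

    def finish_request(self, req_id: str) -> None:
        # Mirrors the finished-request path above: drop the bookkeeping.
        self.num_prompt_logprobs.pop(req_id, None)

    def gather_prompt_logprob_chunks(
        self, num_scheduled_tokens: dict[str, int]
    ) -> dict[str, int]:
        # Stand-in for the per-step prompt-logprobs loop.
        chunks: dict[str, int] = {}
        for req_id in self.num_prompt_logprobs:
            num_tokens = num_scheduled_tokens.get(req_id)
            if num_tokens is None:
                # The request wants prompt logprobs but got no tokens this
                # step, e.g. it was preempted mid-prefill; plain indexing with
                # num_scheduled_tokens[req_id] would raise KeyError here.
                continue
            chunks[req_id] = num_tokens  # placeholder for real logprob work
        return chunks


if __name__ == "__main__":
    runner = Runner()
    runner.add_request("req-a", prompt_logprobs=5)
    runner.add_request("req-b", prompt_logprobs=5)
    # A step in which req-b has been preempted: only req-a is scheduled.
    print(runner.gather_prompt_logprob_chunks({"req-a": 128}))  # {'req-a': 128}

The motivation for moving the state into the runner is visible in the removed InputBatch lines: the batch popped `num_prompt_logprobs` in its remove-request path, which also runs when a request is evicted from the persistent batch on preemption, so a chunked-prefill request lost its prompt-logprobs setting once preempted.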
@@ -149,9 +149,6 @@ class InputBatch:
         self.generators: dict[int, torch.Generator] = {}
 
         self.num_logprobs: dict[str, int] = {}
-        # NOTE(rob): num_prompt_logprobs only includes reqs
-        # that are currently in the prefill phase.
-        self.num_prompt_logprobs: dict[str, int] = {}
 
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -256,8 +253,6 @@ class InputBatch:
 
         if sampling_params.logprobs is not None:
             self.num_logprobs[req_id] = sampling_params.logprobs
-        if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
@@ -317,7 +312,6 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         # LoRA
@@ -584,10 +578,6 @@ class InputBatch:
     def max_num_logprobs(self) -> int | None:
         return max(self.num_logprobs.values()) if self.num_logprobs else None
 
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
     @property
     def no_allowed_token_ids(self) -> bool:
         return len(self.has_allowed_token_ids) == 0
@@ -247,6 +247,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: dict[str, int] = {}
 
         # Initialize input batch early to avoid AttributeError in _update_states
         self.input_batch = InputBatch(
@@ -420,6 +423,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.num_prompt_logprobs.pop(req_id, None)
 
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -477,6 +481,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 lora_request=new_req_data.lora_request,
             )
 
+            if sampling_params and sampling_params.prompt_logprobs is not None:
+                self.num_prompt_logprobs[req_id] = (
+                    self.input_batch.vocab_size
+                    if sampling_params.prompt_logprobs == -1
+                    else sampling_params.prompt_logprobs
+                )
+
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.