@@ -329,8 +329,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def update_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if scheduler_output.preempted_req_ids is not None:
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
|
||||
@@ -346,6 +347,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Add new requests.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
assert new_req_data.prompt_token_ids is not None
|
||||
assert new_req_data.prefill_token_ids is not None
|
||||
assert new_req_data.sampling_params is not None
|
||||
req_id = new_req_data.req_id
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
@@ -398,8 +402,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Decode first, then prefill.
|
||||
# batch_idx -> req_id
|
||||
req_ids = sorted(
|
||||
scheduler_output.num_scheduled_tokens,
|
||||
key=scheduler_output.num_scheduled_tokens.get,
|
||||
scheduler_output.num_scheduled_tokens.keys(),
|
||||
key=lambda k: scheduler_output.num_scheduled_tokens[k],
|
||||
)
|
||||
num_scheduled_tokens = np.array(
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
|
||||
@@ -637,9 +641,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
|
||||
sampled_token_ids=None,
|
||||
sampled_token_ids=None, # type: ignore
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore
|
||||
pooler_output=[],
|
||||
kv_connector_output=None,
|
||||
num_nans_in_logits=None,
|
||||
|
||||
Reference in New Issue
Block a user