[Model Runner V2] Support penalties using bin counts (#29703)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-11-28 17:53:17 -08:00
committed by GitHub
parent ea3370b428
commit 1dcafb3dea
5 changed files with 280 additions and 14 deletions

View File

@@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 idx_mapping_np,
 num_scheduled_tokens,
 query_start_loc_np,
-self.req_states.prefill_token_ids,
+self.req_states.prefill_token_ids.np,
 self.req_states.num_computed_prefill_tokens,
 self.input_buffers.input_ids.np,
 )
@@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 # Handle chunked prompts.
 pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
 is_prompt_chunked = pos_after_step < prompt_lens
-prefill_token_ids = self.req_states.prefill_token_ids
+prefill_token_ids = self.req_states.prefill_token_ids.np
 query_start_loc = self.input_buffers.query_start_loc.np
 for i, req_id in enumerate(input_batch.req_ids):
 if not needs_prompt_logprobs[i]:
@@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 input_batch.idx_mapping,
 self.req_states.num_computed_tokens,
 self.req_states.last_sampled_tokens,
+self.req_states.output_bin_counts,
 sampled_tokens,
 num_sampled,
 num_rejected,
@@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 idx_mapping_np = input_batch.idx_mapping_np
 with async_barrier(self.spec_decode_event):
 self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-self.req_states.prefill_token_ids[
+self.req_states.prefill_token_ids.np[
 idx_mapping_np,
 self.req_states.num_computed_prefill_tokens[idx_mapping_np],
 ]
@@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 # barrier to avoid race conditions.
 pos = input_batch.positions[input_batch.logits_indices]
 sampling_metadata = self.req_states.make_sampling_metadata(
-input_batch.idx_mapping_np, pos
+input_batch.idx_mapping, input_batch.idx_mapping_np, pos
 )
 if input_batch.num_draft_tokens > 0:
 sampling_metadata = self.req_states.expand_sampling_metadata(