[Model Runner V2] Support penalties using bin counts (#29703)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
idx_mapping_np,
|
||||
num_scheduled_tokens,
|
||||
query_start_loc_np,
|
||||
self.req_states.prefill_token_ids,
|
||||
self.req_states.prefill_token_ids.np,
|
||||
self.req_states.num_computed_prefill_tokens,
|
||||
self.input_buffers.input_ids.np,
|
||||
)
|
||||
@@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Handle chunked prompts.
|
||||
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
|
||||
is_prompt_chunked = pos_after_step < prompt_lens
|
||||
prefill_token_ids = self.req_states.prefill_token_ids
|
||||
prefill_token_ids = self.req_states.prefill_token_ids.np
|
||||
query_start_loc = self.input_buffers.query_start_loc.np
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
if not needs_prompt_logprobs[i]:
|
||||
@@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_batch.idx_mapping,
|
||||
self.req_states.num_computed_tokens,
|
||||
self.req_states.last_sampled_tokens,
|
||||
self.req_states.output_bin_counts,
|
||||
sampled_tokens,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
@@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
with async_barrier(self.spec_decode_event):
|
||||
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
|
||||
self.req_states.prefill_token_ids[
|
||||
self.req_states.prefill_token_ids.np[
|
||||
idx_mapping_np,
|
||||
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
|
||||
]
|
||||
@@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# barrier to avoid race conditions.
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.idx_mapping_np, pos
|
||||
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
|
||||
)
|
||||
if input_batch.num_draft_tokens > 0:
|
||||
sampling_metadata = self.req_states.expand_sampling_metadata(
|
||||
|
||||
Reference in New Issue
Block a user