[BugFix] Fix batch updates for pooling models (#23398)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data)
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
output = raw_output.data if seq_len == prompt_len else None
|
||||
pooler_output.append(output)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
@@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
|
||||
max_query_len) = (self._prepare_inputs(scheduler_output))
|
||||
max_query_len) = self._prepare_inputs(scheduler_output)
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
||||
Reference in New Issue
Block a user