[BugFix] Fix batch updates for pooling models (#23398)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-22 17:20:41 -07:00
committed by GitHub
parent 24d0c9e6ed
commit c80c53a30f
3 changed files with 95 additions and 79 deletions

View File

@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
if seq_len == prompt_len:
pooler_output.append(raw_output.data)
else:
pooler_output.append(None)
output = raw_output.data if seq_len == prompt_len else None
pooler_output.append(output)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
max_query_len) = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE