[Model Runner V2] Fix pooling (#36019)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-04 10:53:17 -08:00
committed by GitHub
parent 7faba503c4
commit 417fd28fb1
2 changed files with 2 additions and 2 deletions

View File

@@ -95,8 +95,8 @@ class AsyncPoolingOutput(AsyncModelRunnerOutput):
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
self.copy_event.synchronize()
pooler_output = self.pooler_output_cpu.unbind(dim=0)
if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu):

View File

@@ -1117,7 +1117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output = (
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state
)
self.execute_model_state = None