[BugFix] Fix async scheduling for pooling models (#31584)

Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2025-12-31 14:48:51 -08:00
committed by GitHub
parent d8da76f3b7
commit 6c2cfb62ff
7 changed files with 132 additions and 93 deletions

View File

@@ -575,7 +575,7 @@ class Worker(WorkerBase):
@torch.inference_mode()
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, ModelRunnerOutput | NoneType):
if isinstance(
output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
):
return output
assert isinstance(output, IntermediateTensors)