[BugFix] Fix async scheduling for pooling models (#31584)
Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
@@ -575,7 +575,7 @@ class Worker(WorkerBase):
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
- ) -> ModelRunnerOutput | None:
|
||||
+ ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
|
||||
output = self.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors
|
||||
)
|
||||
- if isinstance(output, ModelRunnerOutput | NoneType):
|
||||
+ if isinstance(
|
||||
+ output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
|
||||
+ ):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
|
||||
Reference in New Issue
Block a user