[Model Runner V2] Fix pooling (#36019)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -95,8 +95,8 @@ class AsyncPoolingOutput(AsyncModelRunnerOutput):
|
||||
self.copy_event.record(copy_stream)
|
||||
|
||||
def get_output(self) -> ModelRunnerOutput:
|
||||
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
|
||||
self.copy_event.synchronize()
|
||||
pooler_output = self.pooler_output_cpu.unbind(dim=0)
|
||||
if self.is_valid_cpu is not None:
|
||||
is_valid_cpu = self.is_valid_cpu.tolist()
|
||||
for i, is_valid in enumerate(is_valid_cpu):
|
||||
|
||||
@@ -1117,7 +1117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
|
||||
- input_batch, _, _, _, hidden_states, _, kv_connector_output = (
|
||||
+ input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None
|
||||
|
||||
Reference in New Issue
Block a user