[Perf] Optimize redundant sync for pooling model, 3.7% Throughput Improvement (#39113)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -1046,6 +1046,13 @@ class GPUModelRunner(
|
||||
def _sync_device(self) -> None:
    """Synchronize the host with the accelerator device.

    Delegates to ``torch.accelerator.synchronize()`` with no arguments and
    returns nothing.
    """
    torch.accelerator.synchronize()
|
||||
|
||||
def _get_or_create_async_output_copy_stream(self) -> torch.cuda.Stream:
|
||||
stream = self.async_output_copy_stream
|
||||
if stream is None:
|
||||
stream = torch.cuda.Stream()
|
||||
self.async_output_copy_stream = stream
|
||||
return stream
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
output.
|
||||
@@ -3152,21 +3159,21 @@ class GPUModelRunner(
|
||||
model_runner_output.pooler_output = [None] * num_reqs
|
||||
return model_runner_output
|
||||
|
||||
if self.use_async_scheduling:
|
||||
return AsyncGPUPoolingModelRunnerOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
if not current_platform.is_cuda_alike():
|
||||
# cpu/xpu runners cannot use the CUDA stream/event-based wrapper.
|
||||
model_runner_output.pooler_output = _copy_pooler_output_to_cpu(
|
||||
raw_pooler_output=raw_pooler_output,
|
||||
finished_mask=finished_mask,
|
||||
async_output_copy_stream=self.async_output_copy_stream,
|
||||
)
|
||||
self._sync_device()
|
||||
return model_runner_output
|
||||
|
||||
model_runner_output.pooler_output = _copy_pooler_output_to_cpu(
|
||||
return AsyncGPUPoolingModelRunnerOutput(
|
||||
model_runner_output=model_runner_output,
|
||||
raw_pooler_output=raw_pooler_output,
|
||||
finished_mask=finished_mask,
|
||||
async_output_copy_stream=self._get_or_create_async_output_copy_stream(),
|
||||
)
|
||||
self._sync_device()
|
||||
|
||||
return model_runner_output
|
||||
|
||||
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
@@ -4333,7 +4340,7 @@ class GPUModelRunner(
|
||||
sampled_token_ids=sampler_output.sampled_token_ids,
|
||||
logprobs_tensors=sampler_output.logprobs_tensors,
|
||||
invalid_req_indices=invalid_req_indices,
|
||||
async_output_copy_stream=self.async_output_copy_stream,
|
||||
async_output_copy_stream=self._get_or_create_async_output_copy_stream(),
|
||||
vocab_size=self.input_batch.vocab_size,
|
||||
)
|
||||
with record_function_or_nullcontext(
|
||||
|
||||
Reference in New Issue
Block a user