[Model Runner V2] Minor CPU optimizations (#34856)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -14,6 +15,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
sampler_output: SamplerOutput,
|
||||
num_sampled_tokens: torch.Tensor,
|
||||
main_stream: torch.cuda.Stream,
|
||||
copy_stream: torch.cuda.Stream,
|
||||
copy_event: torch.cuda.Event,
|
||||
):
|
||||
@@ -25,9 +27,8 @@ class AsyncOutput(AsyncModelRunnerOutput):
|
||||
self.num_sampled_tokens = num_sampled_tokens
|
||||
self.copy_event = copy_event
|
||||
|
||||
default_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(copy_stream):
|
||||
copy_stream.wait_stream(default_stream)
|
||||
with stream(copy_stream, main_stream):
|
||||
copy_stream.wait_stream(main_stream)
|
||||
|
||||
self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
|
||||
self.logprobs_tensors: LogprobsTensors | None = None
|
||||
@@ -71,3 +72,15 @@ class AsyncOutput(AsyncModelRunnerOutput):
|
||||
|
||||
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
|
||||
return x.to("cpu", non_blocking=True).numpy()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def stream(to_stream: torch.cuda.Stream, from_stream: torch.cuda.Stream):
|
||||
"""Lightweight version of torch.cuda.stream() context manager which
|
||||
avoids current_stream and device lookups.
|
||||
"""
|
||||
try:
|
||||
torch.cuda.set_stream(to_stream)
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.set_stream(from_stream)
|
||||
|
||||
Reference in New Issue
Block a user