Replace torch.cuda.Event with torch.Event for better hardware compatibility (#26985)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -255,8 +255,8 @@ def bench_run(
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timing
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies = []
|
||||
for _ in range(num_iters):
|
||||
|
||||
@@ -185,8 +185,8 @@ def benchmark_config(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@@ -105,8 +105,8 @@ def benchmark_permute(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
@@ -241,8 +241,8 @@ def benchmark_unpermute(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@@ -30,8 +30,8 @@ def _time_cuda(
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
for _ in range(bench_iters):
|
||||
|
||||
@@ -253,8 +253,8 @@ def benchmark(
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
# Benchmark
|
||||
latencies: list[float] = []
|
||||
|
||||
@@ -127,8 +127,8 @@ def benchmark_decode(
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
|
||||
@@ -139,8 +139,8 @@ def benchmark_prefill(
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
|
||||
@@ -183,8 +183,8 @@ def benchmark_config(
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event = torch.Event(enable_timing=True)
|
||||
end_event = torch.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
|
||||
@@ -150,8 +150,8 @@ def test_merge_attn_states(
|
||||
output_torch = output.clone()
|
||||
output_lse_torch = output_lse.clone()
|
||||
total_time_torch_kernel = 0
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
# 0. Run the Torch kernel
|
||||
prefix_lse_torch = prefix_lse.clone()
|
||||
@@ -188,8 +188,8 @@ def test_merge_attn_states(
|
||||
output_lse_ref_triton = output_lse.clone()
|
||||
|
||||
total_time_triton_kernel = 0
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.Event(enable_timing=True)
|
||||
end = torch.Event(enable_timing=True)
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_triton(
|
||||
|
||||
@@ -68,9 +68,9 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
||||
self.h2d_stream = torch.cuda.Stream()
|
||||
|
||||
# job_id -> transfer cuda event
|
||||
self.transfer_events: dict[int, torch.cuda.Event] = {}
|
||||
self.transfer_events: dict[int, torch.Event] = {}
|
||||
# list of cuda events available for re-use
|
||||
self.events_pool: list[torch.cuda.Event] = []
|
||||
self.events_pool: list[torch.Event] = []
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
|
||||
@@ -153,7 +153,7 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
|
||||
)
|
||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||
|
||||
event = self.events_pool.pop() if self.events_pool else torch.cuda.Event()
|
||||
event = self.events_pool.pop() if self.events_pool else torch.Event()
|
||||
with torch.cuda.stream(stream):
|
||||
for src_tensor, dst_tensor, kv_dim in zip(
|
||||
src_tensors, dst_tensors, self.kv_dim_before_num_blocks
|
||||
|
||||
@@ -96,14 +96,14 @@ def _torch_cuda_wrapper():
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
cuda_event = torch.cuda.Event
|
||||
cuda_event = torch.Event
|
||||
cuda_stream = torch.cuda.Stream
|
||||
try:
|
||||
torch.cuda.Event = _EventPlaceholder
|
||||
torch.Event = _EventPlaceholder
|
||||
torch.cuda.Stream = _StreamPlaceholder
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.Event = cuda_event
|
||||
torch.Event = cuda_event
|
||||
torch.cuda.Stream = cuda_stream
|
||||
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ class InputBatch:
|
||||
# ids from prior step, if required by current sampling params
|
||||
# (e.g. penalties).
|
||||
self.sampled_token_ids_cpu: torch.Tensor | None = None
|
||||
self.async_copy_ready_event: torch.cuda.Event | None = None
|
||||
self.async_copy_ready_event: torch.Event | None = None
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
@@ -891,7 +891,7 @@ class InputBatch:
|
||||
def set_async_sampled_token_ids(
|
||||
self,
|
||||
sampled_token_ids_cpu: torch.Tensor,
|
||||
async_copy_ready_event: torch.cuda.Event,
|
||||
async_copy_ready_event: torch.Event,
|
||||
) -> None:
|
||||
"""
|
||||
In async scheduling case, store ref to sampled_token_ids_cpu
|
||||
|
||||
@@ -185,7 +185,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
self._invalid_req_indices = invalid_req_indices
|
||||
|
||||
# Event on the copy stream so we can synchronize the non-blocking copy.
|
||||
self.async_copy_ready_event = torch.cuda.Event()
|
||||
self.async_copy_ready_event = torch.Event()
|
||||
|
||||
# Keep a reference to the device tensor to avoid it being
|
||||
# deallocated until we finish copying it to the host.
|
||||
@@ -435,10 +435,10 @@ class GPUModelRunner(
|
||||
self.async_output_copy_stream: torch.cuda.Stream | None = None
|
||||
# cuda event to synchronize use of reused CPU tensors between steps
|
||||
# when async scheduling is enabled.
|
||||
self.prepare_inputs_event: torch.cuda.Event | None = None
|
||||
self.prepare_inputs_event: torch.Event | None = None
|
||||
if self.use_async_scheduling:
|
||||
self.async_output_copy_stream = torch.cuda.Stream()
|
||||
self.prepare_inputs_event = torch.cuda.Event()
|
||||
self.prepare_inputs_event = torch.Event()
|
||||
|
||||
# self.cudagraph_batch_sizes sorts in ascending order.
|
||||
if (
|
||||
@@ -549,7 +549,7 @@ class GPUModelRunner(
|
||||
|
||||
# Cached outputs.
|
||||
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
|
||||
self.transfer_event = torch.cuda.Event()
|
||||
self.transfer_event = torch.Event()
|
||||
self.sampled_token_ids_pinned_cpu = torch.empty(
|
||||
(self.max_num_reqs, 1),
|
||||
dtype=torch.int64,
|
||||
@@ -559,10 +559,10 @@ class GPUModelRunner(
|
||||
|
||||
# Pre-allocated tensor for copying valid sampled token counts to CPU,
|
||||
# with dedicated stream for overlapping and event for coordination.
|
||||
self.valid_sampled_token_count_event: torch.cuda.Event | None = None
|
||||
self.valid_sampled_token_count_event: torch.Event | None = None
|
||||
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
|
||||
if self.use_async_scheduling and self.num_spec_tokens:
|
||||
self.valid_sampled_token_count_event = torch.cuda.Event()
|
||||
self.valid_sampled_token_count_event = torch.Event()
|
||||
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
|
||||
self.valid_sampled_token_count_cpu = torch.empty(
|
||||
self.max_num_reqs,
|
||||
|
||||
@@ -27,8 +27,8 @@ class UBatchContext:
|
||||
ready_barrier: threading.Barrier,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
gpu_comm_done_event: torch.Event,
|
||||
gpu_compute_done_event: torch.Event,
|
||||
schedule: str = "default",
|
||||
):
|
||||
self.id = id
|
||||
@@ -207,8 +207,8 @@ def make_ubatch_contexts(
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
|
||||
gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
|
||||
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
|
||||
|
||||
assert len(forward_contexts) == 2
|
||||
|
||||
|
||||
@@ -37,19 +37,12 @@ class XPUModelRunner(GPUModelRunner):
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
class _EventPlaceholder:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.record = lambda: None
|
||||
self.synchronize = lambda: None
|
||||
|
||||
try:
|
||||
# replace cuda APIs with xpu APIs, this should work by default
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.default_stream = torch.xpu.current_stream
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
yield
|
||||
finally:
|
||||
# if anything goes wrong, just patch it with a placeholder
|
||||
torch.cuda.Event = _EventPlaceholder
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user