Replace torch.cuda.Event with torch.Event for better hardware compatibility (#26985)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2025-11-19 03:34:36 +08:00
committed by GitHub
parent c3e2978620
commit 2a2d5d2780
15 changed files with 41 additions and 48 deletions

View File

@@ -255,8 +255,8 @@ def bench_run(
torch.cuda.synchronize()
# Timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies = []
for _ in range(num_iters):

View File

@@ -185,8 +185,8 @@ def benchmark_config(
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):

View File

@@ -105,8 +105,8 @@ def benchmark_permute(
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):

View File

@@ -30,8 +30,8 @@ def _time_cuda(
fn()
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
start.record()
for _ in range(bench_iters):

View File

@@ -253,8 +253,8 @@ def benchmark(
)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
# Benchmark
latencies: list[float] = []

View File

@@ -127,8 +127,8 @@ def benchmark_decode(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()

View File

@@ -139,8 +139,8 @@ def benchmark_prefill(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()

View File

@@ -183,8 +183,8 @@ def benchmark_config(
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):

View File

@@ -150,8 +150,8 @@ def test_merge_attn_states(
output_torch = output.clone()
output_lse_torch = output_lse.clone()
total_time_torch_kernel = 0
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
# 0. Run the Torch kernel
prefix_lse_torch = prefix_lse.clone()
@@ -188,8 +188,8 @@ def test_merge_attn_states(
output_lse_ref_triton = output_lse.clone()
total_time_triton_kernel = 0
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
for _ in range(warmup_times):
merge_attn_states_triton(

View File

@@ -68,9 +68,9 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self.h2d_stream = torch.cuda.Stream()
# job_id -> transfer cuda event
self.transfer_events: dict[int, torch.cuda.Event] = {}
self.transfer_events: dict[int, torch.Event] = {}
# list of cuda events available for re-use
self.events_pool: list[torch.cuda.Event] = []
self.events_pool: list[torch.Event] = []
pin_memory = is_pin_memory_available()
@@ -153,7 +153,7 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
)
src_to_dst_tensor = torch.from_numpy(src_to_dst)
event = self.events_pool.pop() if self.events_pool else torch.cuda.Event()
event = self.events_pool.pop() if self.events_pool else torch.Event()
with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip(
src_tensors, dst_tensors, self.kv_dim_before_num_blocks

View File

@@ -96,14 +96,14 @@ def _torch_cuda_wrapper():
def __init__(self, *args, **kwargs) -> None:
pass
cuda_event = torch.cuda.Event
cuda_event = torch.Event
cuda_stream = torch.cuda.Stream
try:
torch.cuda.Event = _EventPlaceholder
torch.Event = _EventPlaceholder
torch.cuda.Stream = _StreamPlaceholder
yield
finally:
torch.cuda.Event = cuda_event
torch.Event = cuda_event
torch.cuda.Stream = cuda_stream

View File

@@ -265,7 +265,7 @@ class InputBatch:
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.cuda.Event | None = None
self.async_copy_ready_event: torch.Event | None = None
@property
def req_ids(self) -> list[str]:
@@ -891,7 +891,7 @@ class InputBatch:
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.cuda.Event,
async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu

View File

@@ -185,7 +185,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
self.async_copy_ready_event = torch.cuda.Event()
self.async_copy_ready_event = torch.Event()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
@@ -435,10 +435,10 @@ class GPUModelRunner(
self.async_output_copy_stream: torch.cuda.Stream | None = None
# cuda event to synchronize use of reused CPU tensors between steps
# when async scheduling is enabled.
self.prepare_inputs_event: torch.cuda.Event | None = None
self.prepare_inputs_event: torch.Event | None = None
if self.use_async_scheduling:
self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.cuda.Event()
self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes sorts in ascending order.
if (
@@ -549,7 +549,7 @@ class GPUModelRunner(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self.transfer_event = torch.cuda.Event()
self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1),
dtype=torch.int64,
@@ -559,10 +559,10 @@ class GPUModelRunner(
# Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.cuda.Event | None = None
self.valid_sampled_token_count_event: torch.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.cuda.Event()
self.valid_sampled_token_count_event = torch.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,

View File

@@ -27,8 +27,8 @@ class UBatchContext:
ready_barrier: threading.Barrier,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
gpu_comm_done_event: torch.Event,
gpu_compute_done_event: torch.Event,
schedule: str = "default",
):
self.id = id
@@ -207,8 +207,8 @@ def make_ubatch_contexts(
Create a context manager for micro-batching synchronization.
"""
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2

View File

@@ -37,19 +37,12 @@ class XPUModelRunner(GPUModelRunner):
@contextmanager
def _torch_cuda_wrapper():
class _EventPlaceholder:
def __init__(self, *args, **kwargs) -> None:
self.record = lambda: None
self.synchronize = lambda: None
try:
# replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
yield
finally:
# if anything goes wrong, just patch it with a placeholder
torch.cuda.Event = _EventPlaceholder
pass