[Hardware] Replace torch.cuda.synchronize() api with torch.accelerator.synchronize (#36085)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -135,14 +135,14 @@ def benchmark_mrope(
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
# Time reference implementation
|
||||
torch_times = []
|
||||
for _ in range(benchmark_iter):
|
||||
query_clone = query.clone()
|
||||
key_clone = key.clone()
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
mrope_helper_class.forward_native(
|
||||
@@ -151,7 +151,7 @@ def benchmark_mrope(
|
||||
key_clone,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
torch_times.append(time.time() - start_time)
|
||||
|
||||
# Time triton kernel implementation
|
||||
@@ -159,14 +159,14 @@ def benchmark_mrope(
|
||||
for _ in range(benchmark_iter):
|
||||
query_clone = query.clone()
|
||||
key_clone = key.clone()
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
start_time = time.time()
|
||||
mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query_clone,
|
||||
key_clone,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
triton_times.append(time.time() - start_time)
|
||||
|
||||
# Calculate statistics
|
||||
|
||||
Reference in New Issue
Block a user