[Hardware] Replace torch.cuda.empty_cache with torch.accelerator.empty_cache (#30681)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Kunshang Ji
2026-03-04 17:49:47 +08:00
committed by GitHub
parent 5dc3538736
commit 16d2ad1d38
35 changed files with 110 additions and 59 deletions

View File

@@ -179,7 +179,7 @@ def test_ngram_and_suffix_correctness(
)
evaluate_llm_for_gsm8k(spec_llm)
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
@@ -240,7 +240,7 @@ def test_suffix_decoding_acceptance(
assert last_accept_rate > 0.80
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
@@ -307,14 +307,14 @@ def test_speculators_model_integration(
verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Second run: Reference without speculative decoding
ref_llm = LLM(model=verifier_model, max_model_len=4096)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
# Compare outputs
@@ -410,7 +410,7 @@ def _run_eagle_correctness(
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
@@ -445,7 +445,7 @@ def _run_eagle_correctness(
assert matches > int(0.6 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
@@ -715,7 +715,7 @@ def test_mtp_correctness(
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
)
del ref_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
@@ -747,7 +747,7 @@ def test_mtp_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
@@ -952,7 +952,7 @@ def assert_draft_model_correctness(args: ArgsTest):
)
del spec_llm # CLEANUP
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
print(