[Hardware] Replace torch.cuda.empty_cache with torch.accelerator.empty_cache (#30681)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Kunshang Ji
2026-03-04 17:49:47 +08:00
committed by GitHub
parent 5dc3538736
commit 16d2ad1d38
35 changed files with 110 additions and 59 deletions

View File

@@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight

View File

@@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
# secondly, process mxfp weights
if self.emulate:
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return
from aiter.utility.fp4_utils import e8m0_shuffle
@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:

View File

@@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
import torch.nn.functional as F
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight

View File

@@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)