[Hardware] Replace torch.cuda.empty_cache with torch.accelerator.empty_cache (#30681)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -200,7 +200,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
):
|
||||
num_pad = 256 // weight.element_size()
|
||||
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
@@ -961,7 +961,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
|
||||
# secondly, process mxfp weights
|
||||
if self.emulate:
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
return
|
||||
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
@@ -995,7 +995,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -1116,7 +1116,7 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
|
||||
if self.static_input_scales:
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
|
||||
@@ -1407,7 +1407,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
|
||||
import torch.nn.functional as F
|
||||
|
||||
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
@@ -811,7 +811,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
**stacked_quant_state_dict,
|
||||
}
|
||||
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
|
||||
- torch.cuda.empty_cache()
|
||||
+ torch.accelerator.empty_cache()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
Reference in New Issue
Block a user