[CPU] Refactor CPU W8A8 scaled_mm (#23071)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -199,11 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
|
||||
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
N, K = layer.weight.size()
|
||||
dtype = layer.weight.dtype
|
||||
if (torch._C._cpu._is_amx_tile_supported()
|
||||
and dtype == torch.bfloat16 and N % 32 == 0
|
||||
and K % 32 == 0):
|
||||
if check_cpu_sgl_kernel(N, K, dtype):
|
||||
packed_weight = torch.ops._C.convert_weight_packed(
|
||||
layer.weight)
|
||||
assert packed_weight.size() == layer.weight.size()
|
||||
@@ -215,7 +214,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
else:
|
||||
logger.warning(
|
||||
"CPU SGL kernels require Intel AMX support,"
|
||||
" bfloat16 weight, IC and OC are divisible by 32.")
|
||||
" bf16/fp16/int8 weight, IC and OC are divisible by "
|
||||
"32 and 16.")
|
||||
layer.use_cpu_sgl = False
|
||||
|
||||
def apply(self,
|
||||
|
||||
Reference in New Issue
Block a user