[CPU] Refactor CPU W8A8 scaled_mm (#23071)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-08-21 09:34:24 +08:00
committed by GitHub
parent b029de9902
commit 7be5d113d8
17 changed files with 1525 additions and 1273 deletions

View File

@@ -199,11 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
dtype = layer.weight.dtype
if (torch._C._cpu._is_amx_tile_supported()
and dtype == torch.bfloat16 and N % 32 == 0
and K % 32 == 0):
if check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(
layer.weight)
assert packed_weight.size() == layer.weight.size()
@@ -215,7 +214,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
else:
logger.warning(
"CPU SGL kernels require Intel AMX support,"
" bfloat16 weight, IC and OC are divisible by 32.")
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16.")
layer.use_cpu_sgl = False
def apply(self,