[CPU] Update custom ops for the CPU backend (#20255)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-07-01 15:25:03 +08:00
committed by GitHub
parent 9909726d2a
commit 6cc1e7d96d
23 changed files with 5357 additions and 101 deletions

View File

@@ -9,6 +9,7 @@ import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -27,6 +28,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -195,12 +197,33 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
N, K = layer.weight.size()
dtype = layer.weight.dtype
if (torch._C._cpu._is_amx_tile_supported()
and dtype == torch.bfloat16 and N % 32 == 0
and K % 32 == 0):
packed_weight = torch.ops._C.convert_weight_packed(
layer.weight)
assert packed_weight.size() == layer.weight.size()
layer.weight.copy_(packed_weight)
if layer.bias is not None:
layer.bias = Parameter(layer.bias.to(torch.float32),
requires_grad=False)
layer.use_cpu_sgl = True
else:
logger.warning(
"CPU SGL kernels require Intel AMX support,"
" bfloat16 weight, IC and OC are divisible by 32.")
layer.use_cpu_sgl = False
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
class LinearBase(torch.nn.Module):