[CPU] Update custom ops for the CPU backend (#20255)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@@ -27,6 +28,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
RowvLLMParameter)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -195,12 +197,33 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Prepare ``layer`` for the CPU SGL kernels once its weights are loaded.

    Nothing happens unless the current platform is CPU and
    ``envs.VLLM_CPU_SGL_KERNEL`` is set. When both hold, the outcome is
    recorded on the layer as ``layer.use_cpu_sgl``:

    * ``True``  - the weight was passed through
      ``torch.ops._C.convert_weight_packed`` and copied back in place, and
      any bias was replaced by a float32, non-trainable ``Parameter``.
    * ``False`` - the requirements were not met; a warning is logged and
      the layer is left unchanged.

    Args:
        layer: Module whose ``weight`` (and optional ``bias``) are
            processed. ``weight`` must be 2-D, since its size is unpacked
            into exactly two dimensions.
    """
    sgl_requested = current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL
    if not sgl_requested:
        return

    n_rows, n_cols = layer.weight.size()
    weight_dtype = layer.weight.dtype

    # All four conditions must hold; the AMX probe is evaluated first.
    supported = (torch._C._cpu._is_amx_tile_supported()
                 and weight_dtype == torch.bfloat16
                 and n_rows % 32 == 0
                 and n_cols % 32 == 0)
    if not supported:
        logger.warning(
            "CPU SGL kernels require Intel AMX support,"
            " bfloat16 weight, IC and OC are divisible by 32.")
        layer.use_cpu_sgl = False
        return

    repacked = torch.ops._C.convert_weight_packed(layer.weight)
    # The packed tensor is written over the existing parameter storage, so
    # its shape has to match the original exactly.
    assert repacked.size() == layer.weight.size()
    layer.weight.copy_(repacked)

    current_bias = layer.bias
    if current_bias is not None:
        # NOTE(review): presumably the SGL kernel expects a float32 bias;
        # confirm against the kernel implementation.
        layer.bias = Parameter(current_bias.to(torch.float32),
                               requires_grad=False)
    layer.use_cpu_sgl = True
|
||||
|
||||
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the unquantized linear transform of ``layer`` to ``x``.

    Args:
        layer: Module that owns the ``weight`` used for the GEMM.
        x: Input tensor.
        bias: Optional bias forwarded to the GEMM implementation.

    Returns:
        The tensor produced by the callable that
        ``dispatch_unquantized_gemm()`` returns.
    """
    # A stale pre-change call, ``dispatch_unquantized_gemm()(x,
    # layer.weight, bias)``, used to sit in front of this return and made
    # it unreachable. Only the call that forwards ``layer`` is kept.
    # NOTE(review): presumably the dispatched GEMM reads per-layer state
    # such as ``layer.use_cpu_sgl`` (set in process_weights_after_loading);
    # confirm against dispatch_unquantized_gemm.
    return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user