[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
|
||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
@@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
|
||||
device)
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
@@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
has_zp=self.config.zero_points,
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
|
||||
Reference in New Issue
Block a user