[V1][Quantization] Add CUDA graph compatible v1 GGUF support (#18646)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
@@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -96,8 +96,8 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
|
||||
qweight_type: int) -> torch.Tensor:
|
||||
# HACK: when doing chunked prefill we don't generate output tokens
|
||||
# so input to logits generator is empty which causes invalid parameter
|
||||
if x.shape[0] == 0:
|
||||
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
return y
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0],
|
||||
qweight.shape[0],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_mul_mat_gguf",
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _fused_moe_gguf(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
act,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
|
||||
def act(x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(out, x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
return out
|
||||
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
moe_align_block_size)
|
||||
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = w1[ii]
|
||||
|
||||
out = _fuse_mul_mat(inp, expert_up, qweight_type)
|
||||
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
|
||||
out = act(out)
|
||||
|
||||
expert_down = w2[ii]
|
||||
current_state = _fuse_mul_mat(out, expert_down,
|
||||
qweight_type2).mul_(ww)
|
||||
current_state = fused_mul_mat_gguf(out, expert_down,
|
||||
qweight_type2).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def _fused_moe_gguf_fake(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
qweight_type: int,
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_moe_gguf",
|
||||
op_func=_fused_moe_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
def _apply_gguf_embedding(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return torch.embedding(qweight, x)
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
x_flat = x.flatten()
|
||||
assert (hidden_size == qweight.shape[1] // type_size * block_size)
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0], dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
else:
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise NotImplementedError(
|
||||
f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
|
||||
|
||||
def _apply_gguf_embedding_fake(
|
||||
x: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
qweight_type: int,
|
||||
hidden_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_gguf_embedding",
|
||||
op_func=_apply_gguf_embedding,
|
||||
mutates_args=[],
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
@@ -249,26 +358,76 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
set_weight_attrs(qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("qweight_type", qweight_type)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
if not (qweight_type in UNQUANTIZED_TYPES
|
||||
or qweight_type in DEQUANT_TYPES):
|
||||
qweight_type = WeightType(qweight_type)
|
||||
raise ValueError(
|
||||
f"Unsupported GGUF quantization type {qweight_type} in "
|
||||
f"layer {layer}.")
|
||||
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
|
||||
# materialize the padded weight parameter for CUDA Graph compatibility.
|
||||
self._create_padded_weight_param(layer)
|
||||
|
||||
def _create_padded_weight_param(self, layer: torch.nn.Module):
|
||||
"""Create padded weight parameter for GGUF MergedLinear layer."""
|
||||
qweight = layer.qweight
|
||||
shard_id_map = qweight.shard_id_map
|
||||
shard_id = qweight.shard_id
|
||||
if len(data_container := qweight.data_container) > 1:
|
||||
dtype = {data.dtype for data in data_container}
|
||||
assert len(dtype) == 1, ValueError(
|
||||
f"Data container has mixed dtypes: {dtype}")
|
||||
dtype = next(iter(dtype))
|
||||
# concat dim0 and pad dim1
|
||||
padded_side = max(x.size(1) for x in data_container)
|
||||
concat_side = sum(x.size(0) for x in data_container)
|
||||
# Pad the quantized weights to dense tensor, and create a map
|
||||
# with the location of each shard in the padded tensor.
|
||||
padded_data = torch.zeros((concat_side, padded_side),
|
||||
dtype=dtype,
|
||||
device=qweight.device)
|
||||
# (dim0_start, dim0_end, dim1_size)
|
||||
shard_offset_map = dict[str, tuple[int, int, int]]()
|
||||
for idx in shard_id:
|
||||
id_in_container = shard_id_map[idx]
|
||||
start = sum(
|
||||
x.size(0) for x in data_container[:id_in_container])
|
||||
end = start + data_container[id_in_container].size(0)
|
||||
size = data_container[id_in_container].size(1)
|
||||
padded_data[start:end, :size] = data_container[id_in_container]
|
||||
shard_offset_map[idx] = (start, end, size)
|
||||
qweight.data_container.clear()
|
||||
padded_param = Parameter(padded_data, requires_grad=False)
|
||||
set_weight_attrs(padded_param, vars(qweight))
|
||||
set_weight_attrs(padded_param,
|
||||
{"shard_offset_map": shard_offset_map})
|
||||
layer.register_parameter("qweight", padded_param)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
shard_id = getattr(layer.qweight, "shard_id", None)
|
||||
shard_id = layer.qweight.shard_id
|
||||
|
||||
if shard_id:
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
qweight = layer.qweight.unbind(0)
|
||||
qweight = layer.qweight
|
||||
result = []
|
||||
for idx in shard_id:
|
||||
q_idx = layer.qweight.shard_id_map[idx]
|
||||
start, end, offset = layer.qweight.shard_offset_map[idx]
|
||||
qweight_type = layer.qweight_type.shard_weight_type[idx]
|
||||
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
|
||||
result.append(
|
||||
fused_mul_mat_gguf(
|
||||
x, qweight[start:end, :offset].contiguous(),
|
||||
qweight_type))
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
out = _fuse_mul_mat(x, qweight, qweight_type)
|
||||
out = fused_mul_mat_gguf(x, qweight, qweight_type)
|
||||
if bias is not None:
|
||||
out.add_(bias)
|
||||
return out
|
||||
@@ -338,7 +497,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
self.act = SiluAndMul()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -375,10 +533,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, self.act)
|
||||
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, activation)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
@@ -392,34 +550,15 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
x: torch.Tensor) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
hidden_size = qweight.shape[1] // type_size * block_size
|
||||
if qweight_type < 2:
|
||||
return torch.embedding(qweight, x)
|
||||
x_flat = x.flatten()
|
||||
quant = torch.index_select(qweight, dim=0, index=x_flat)
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0], self.params_dtype)
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
return apply_gguf_embedding(x,
|
||||
qweight,
|
||||
qweight_type,
|
||||
hidden_size,
|
||||
dtype=self.params_dtype)
|
||||
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: list[torch.Tensor]
|
||||
|
||||
def materialize_nested(self) -> Parameter:
|
||||
dtype = {data.dtype for data in self.data_container}
|
||||
assert len(dtype) == 1, ValueError(
|
||||
f"Data container has mixed dtypes: {dtype}")
|
||||
dtype = next(iter(dtype))
|
||||
nested_data = torch.nested.nested_tensor(self.data_container,
|
||||
device=self.device,
|
||||
dtype=dtype)
|
||||
self.data_container.clear()
|
||||
param = torch.Tensor._make_subclass(self.cls_to_become,
|
||||
nested_data,
|
||||
require_grad=False)
|
||||
for k, v in self.__dict__.items():
|
||||
setattr(param, k, v)
|
||||
return param
|
||||
|
||||
Reference in New Issue
Block a user