[Kernel] GGUF MoE kernel (#14613)

Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
This commit is contained in:
Szymon Ożóg
2025-03-12 04:33:27 +01:00
committed by GitHub
parent e392d85831
commit e22ee1e7a2
8 changed files with 1070 additions and 25 deletions

View File

@@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
@@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return y
def _fused_moe_gguf(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
act,
) -> torch.Tensor:
out_hidden_states = torch.empty_like(x)
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type, N, top_k,
num_tokens)
out = act(out)
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type2,
w2.shape[1], 1, num_tokens * top_k)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1))
ops.moe_sum(out, out_hidden_states)
else:
logger.warning_once("There is no support for fast MoE kernel "
"for current quantization method. "
"Falling back to slow implementation. ")
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = w1[ii]
out = _fuse_mul_mat(inp, expert_up, qweight_type)
out = act(out)
expert_down = w2[ii]
current_state = _fuse_mul_mat(out, expert_down,
qweight_type2).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
out_hidden_states[tok] = current_hidden_state
return out_hidden_states
class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.
@@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
final_hidden_states = torch.empty_like(x)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = layer.w13_qweight[ii]
out = _fuse_mul_mat(inp, expert_up,
layer.w13_qweight_type.weight_type)
out = self.act(out)
expert_down = layer.w2_qweight[ii]
current_state = _fuse_mul_mat(
out, expert_down,
layer.w2_qweight_type.weight_type).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
final_hidden_states[tok] = current_hidden_state
return final_hidden_states
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, self.act)
class GGUFEmbeddingMethod(GGUFLinearMethod):