[Kernel] GGUF MoE kernel (#14613)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
This commit is contained in:
@@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
@@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFConfig(QuantizationConfig):
|
||||
"""Config class for GGUF."""
|
||||
@@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
|
||||
return y
|
||||
|
||||
|
||||
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    act,
) -> torch.Tensor:
    """Run a mixture-of-experts MLP over GGUF-quantized expert weights.

    Two code paths exist:

    * When both ``qweight_type`` and ``qweight_type2`` are members of
      ``MMQ_QUANT_TYPES``, tokens are grouped by expert with
      ``moe_align_block_size`` and both projections are executed through
      ``ops.ggml_moe_a8``; the per-expert results are scaled by
      ``topk_weights`` and reduced with ``ops.moe_sum``.
    * Otherwise a warning is logged once and every (token, expert) pair
      is evaluated individually with ``_fuse_mul_mat``.

    Args:
        x: Input activations. The fast path unpacks ``x.shape`` into two
            values, so it must be 2-D there (first dim = token count).
        w1: Stacked per-expert weights for the first projection, indexed
            by expert id along dim 0 (3-D on the fast path).
        w2: Stacked per-expert weights for the second projection, indexed
            by expert id along dim 0; ``w2.shape[1]`` is used as the
            output row count of the second projection.
        topk_weights: Per-token routing weights; viewed as
            ``(num_tokens, top_k, 1)`` on the fast path.
        topk_ids: Per-token selected expert ids; ``topk_ids.shape[1]`` is
            taken as ``top_k``.
        qweight_type: Quantization type id associated with ``w1``.
        qweight_type2: Quantization type id associated with ``w2``.
        act: Callable applied to the output of the first projection
            before it is fed to the second one.

    Returns:
        A tensor allocated with ``torch.empty_like(x)`` holding the
        weighted combination of the selected experts' outputs per token.
    """
    out_hidden_states = torch.empty_like(x)
    use_fused_kernel = (qweight_type2 in MMQ_QUANT_TYPES
                        and qweight_type in MMQ_QUANT_TYPES)

    if use_fused_kernel:
        n_tokens, _ = x.shape
        n_experts, up_rows, _ = w1.shape
        experts_per_token = topk_ids.shape[1]
        block_size = ops.ggml_moe_get_block_size(qweight_type)

        # Group token slots by expert, padded to the kernel block size.
        alignment = moe_align_block_size(topk_ids, block_size, n_experts)
        sorted_token_ids, expert_ids, num_tokens_post_padded = alignment

        # First projection: every token is routed to `top_k` experts.
        out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type, up_rows,
                              experts_per_token, n_tokens)
        out = act(out)

        # Second projection: the intermediate already holds one row per
        # (token, expert) slot, hence top_k=1 and n_tokens * top_k rows.
        down_rows = w2.shape[1]
        out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
                              num_tokens_post_padded, qweight_type2,
                              down_rows, 1, n_tokens * experts_per_token)

        # Scale each expert's output by its routing weight (in place),
        # then reduce over the expert axis into the preallocated output.
        out = out.reshape(n_tokens, experts_per_token, down_rows)
        out.mul_(topk_weights.view(n_tokens, experts_per_token, 1))
        ops.moe_sum(out, out_hidden_states)
        return out_hidden_states

    logger.warning_once("There is no support for fast MoE kernel "
                        "for current quantization method. "
                        "Falling back to slow implementation. ")

    # Shape of a single token kept as a one-row batch.
    row_shape = (1, ) + x.shape[1:]
    for tok, (tok_weights, tok_experts) in enumerate(
            zip(topk_weights, topk_ids)):
        token_in = x[tok].reshape(row_shape)
        accumulated = None
        for weight, expert in zip(tok_weights, tok_experts):
            up_weight = w1[expert]
            hidden = _fuse_mul_mat(token_in, up_weight, qweight_type)
            hidden = act(hidden)

            down_weight = w2[expert]
            contribution = _fuse_mul_mat(hidden, down_weight,
                                         qweight_type2).mul_(weight)
            # The first contribution becomes the accumulator; later ones
            # are added into it in place.
            if accumulated is None:
                accumulated = contribution
            else:
                accumulated.add_(contribution)
        out_hidden_states[tok] = accumulated
    return out_hidden_states
|
||||
|
||||
|
||||
class GGUFLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GGUF.
|
||||
|
||||
@@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
final_hidden_states = torch.empty_like(x)
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = layer.w13_qweight[ii]
|
||||
|
||||
out = _fuse_mul_mat(inp, expert_up,
|
||||
layer.w13_qweight_type.weight_type)
|
||||
out = self.act(out)
|
||||
|
||||
expert_down = layer.w2_qweight[ii]
|
||||
current_state = _fuse_mul_mat(
|
||||
out, expert_down,
|
||||
layer.w2_qweight_type.weight_type).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
final_hidden_states[tok] = current_hidden_state
|
||||
return final_hidden_states
|
||||
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
|
||||
topk_weights, topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type, self.act)
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
|
||||
Reference in New Issue
Block a user