[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)

Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
ElizaWszola
2024-10-04 20:34:44 +02:00
committed by GitHub
parent 0dcc8cbe5a
commit 05d686432f
23 changed files with 969 additions and 223 deletions

View File

@@ -12,6 +12,7 @@ import torch
from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
@@ -974,6 +975,50 @@ def fp8_allclose(
equal_nan=equal_nan)).item())
# Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,