[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
@@ -12,6 +12,7 @@ import torch
|
||||
from torch._prims_common import TensorLikeType
|
||||
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
make_tensor_with_pad)
|
||||
|
||||
@@ -974,6 +975,50 @@ def fp8_allclose(
|
||||
equal_nan=equal_nan)).item())
|
||||
|
||||
|
||||
# Marlin MoE test utils
|
||||
|
||||
|
||||
def stack_and_dev(tensors: List[torch.Tensor]):
|
||||
dev = tensors[0].device
|
||||
return torch.stack(tensors, dim=0).to(dev)
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, score, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe_single(a, w, score, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
_, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = a[mask] @ w[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w.shape[1])).sum(dim=1)
|
||||
|
||||
|
||||
# A special version of op check that has a restricted default set of test_utils
|
||||
# and a patched version of allclose that supports fp8 types.
|
||||
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
||||
|
||||
Reference in New Issue
Block a user