[Misc] Add unit tests for MoE ModularKernel combinations + Profiling utility (#20449)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
6fb162447b
commit
53fa457391
@@ -1072,6 +1072,7 @@ def torch_experts(
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
apply_router_weights_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert (global_num_experts == -1
|
||||
or (global_num_experts == w1.shape[0] and expert_map is None)
|
||||
@@ -1081,11 +1082,17 @@ def torch_experts(
|
||||
M, K = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
if apply_router_weights_on_input:
|
||||
assert topk == 1
|
||||
a = a * topk_weight.to(a.dtype)
|
||||
|
||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
||||
|
||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
|
||||
if a1_scale:
|
||||
assert not per_act_token_quant and block_shape is None
|
||||
a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype,
|
||||
per_act_token_quant, block_shape)
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
@@ -1104,6 +1111,7 @@ def torch_experts(
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
elif block_shape is not None:
|
||||
# block quantized
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||
@@ -1121,15 +1129,27 @@ def torch_experts(
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
|
||||
|
||||
tmp1 = a[mask].to(f32) * scales
|
||||
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||
tmp1 = tmp1 @ w1_dq
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
|
||||
|
||||
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
|
||||
|
||||
tmp2, b_scale = moe_kernel_quantize_input(
|
||||
tmp2, a2_scale, quant_dtype, per_act_token_quant,
|
||||
block_shape)
|
||||
assert b_scale is not None
|
||||
|
||||
tmp2 = tmp2.to(f32) * b_scale
|
||||
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||
|
||||
return (out.view(M, -1, w2.shape[1]).to(f32) *
|
||||
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
|
||||
if apply_router_weights_on_input:
|
||||
return out
|
||||
else:
|
||||
return (out.view(M, -1, w2.shape[1]).to(f32) *
|
||||
topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
|
||||
|
||||
|
||||
def torch_moe(a: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user