[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)
This commit is contained in:
@@ -7,17 +7,22 @@ from math import prod
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8,
|
||||
CutlassExpertsFp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -150,16 +155,15 @@ class MOETensors8Bit(MOETensors):
|
||||
|
||||
|
||||
def run_with_expert_maps(
|
||||
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
**cutlass_moe_kwargs,
|
||||
):
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q",
|
||||
"w2_q",
|
||||
"ab_strides1",
|
||||
"ab_strides2",
|
||||
"c_strides1",
|
||||
"c_strides2",
|
||||
"w1",
|
||||
"w2",
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
@@ -167,8 +171,6 @@ def run_with_expert_maps(
|
||||
if k in slice_params and k in cutlass_moe_kwargs
|
||||
}
|
||||
|
||||
quant_config = cutlass_moe_kwargs["quant_config"]
|
||||
|
||||
for i in range(0, num_experts, num_local_experts):
|
||||
s, e = i, i + num_local_experts
|
||||
|
||||
@@ -187,13 +189,23 @@ def run_with_expert_maps(
|
||||
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
|
||||
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
|
||||
|
||||
cutlass_moe_kwargs["quant_config"] = new_quant_config
|
||||
yield cutlass_moe_kwargs, new_quant_config
|
||||
|
||||
yield cutlass_moe_kwargs
|
||||
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||
for kwargs in slice_experts():
|
||||
out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
|
||||
for kwargs, new_quant_config in slice_experts():
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
out_dtype=kwargs["hidden_states"].dtype,
|
||||
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
|
||||
e=kwargs["w2"].shape[0], # type: ignore[union-attr]
|
||||
n=kwargs["w2"].shape[2], # type: ignore[union-attr]
|
||||
k=kwargs["w2"].shape[1], # type: ignore[union-attr]
|
||||
quant_config=new_quant_config,
|
||||
device="cuda",
|
||||
),
|
||||
)
|
||||
out_tensor = out_tensor + kernel(**kwargs)
|
||||
|
||||
return out_tensor
|
||||
|
||||
@@ -230,27 +242,35 @@ def run_8_bit(
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"a": moe_tensors.a,
|
||||
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"hidden_states": moe_tensors.a,
|
||||
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"ab_strides1": moe_tensors.ab_strides1,
|
||||
"ab_strides2": moe_tensors.ab_strides2,
|
||||
"c_strides1": moe_tensors.c_strides1,
|
||||
"c_strides2": moe_tensors.c_strides2,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
return cutlass_moe_fp8(**kwargs)
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
out_dtype=moe_tensors.a.dtype,
|
||||
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
|
||||
e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
||||
n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||
k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||
quant_config=quant_config,
|
||||
device="cuda",
|
||||
),
|
||||
)
|
||||
return kernel(**kwargs)
|
||||
|
||||
assert num_local_experts is not None
|
||||
return run_with_expert_maps(
|
||||
num_experts,
|
||||
num_local_experts, # type: ignore[arg-type]
|
||||
quant_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user