[MoE Refactor][3/N] Deprecate cutlass block quant fp8 (b200) (#30990)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-12-19 16:09:54 -05:00
committed by GitHub
parent 5f6477d1d0
commit 83a317f650
8 changed files with 3 additions and 704 deletions

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
)
-from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -896,162 +896,6 @@ def cutlass_moe_fp4(
)
def _valid_cutlass_block_scaled_grouped_gemm(
w1: torch.Tensor,
w2: torch.Tensor,
inplace: bool,
activation: str,
apply_router_weight_on_input: bool,
expert_map: torch.Tensor | None,
) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0
_, K, N = w2.size()
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: unaligned problem size. "
"N: %s, K: %s",
N,
K,
)
return False
if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). "
"w1.dtype: %s, w2.dtype: %s",
w1.dtype,
w2.dtype,
)
return False
if expert_map is not None:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported."
)
return False
if activation != "silu":
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: only activation silu is supported."
)
return False
if apply_router_weight_on_input:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported."
)
return False
if inplace:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
)
return False
return True
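# Illustrative usage sketch (not part of this diff): exercising the gate above
# with assumed weight shapes. Here w1 is (E, 2 * inter, hidden) and w2 is
# (E, hidden, inter), matching the transposes used in the kernel wrapper below;
# both dimensions are multiples of 128, so the alignment check passes.
E, hidden, inter = 8, 4096, 14336
w1_demo = torch.empty((E, 2 * inter, hidden), dtype=torch.float8_e4m3fn, device="cuda")
w2_demo = torch.empty((E, hidden, inter), dtype=torch.float8_e4m3fn, device="cuda")
ok = _valid_cutlass_block_scaled_grouped_gemm(
    w1_demo,
    w2_demo,
    inplace=False,
    activation="silu",
    apply_router_weight_on_input=False,
    expert_map=None,
)
# ok is True only when N and K are 128-aligned, both weights are float8_e4m3fn,
# no expert_map is given, activation is "silu", and neither inplace nor
# apply_router_weight_on_input is requested.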
# TODO(bnell): would be nice to combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
w1_q = w1.transpose(1, 2)
w2_q = w2.transpose(1, 2)
w1_scale = w1_scale.transpose(1, 2)
w2_scale = w2_scale.transpose(1, 2)
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert a.shape[0] == topk_ids.shape[0], (
"a and topk_ids must have the same batch size"
)
assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn"
assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn"
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
out_dtype = a.dtype
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(
a, A_scale=None, per_act_token=False, block_shape=[128, 128]
)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
ops.get_cutlass_moe_mm_data(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map]
c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device)
c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device)
ops.cutlass_blockwise_scaled_grouped_mm(
c1,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
problem_sizes1,
expert_offsets[:-1],
)
intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device)
torch.ops._C.silu_and_mul(intermediate, c1)
intermediate_q, a2_scale = _fp8_quantize(
intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128]
)
ops.cutlass_blockwise_scaled_grouped_mm(
c2,
intermediate_q,
w2_q,
a2_scale,
w2_scale,
problem_sizes2,
expert_offsets[:-1],
)
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)
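# Hypothetical end-to-end sketch (not part of this diff), continuing the demo
# weights above. Scale layouts and routing tensors are assumptions: the scales
# follow the [128, 128] block shape of the weights, and topk_ids/topk_weights
# would normally come from the router.
M, topk = 32, 2
a_demo = torch.randn((M, hidden), dtype=torch.bfloat16, device="cuda")
w1_scale_demo = torch.empty(
    (E, 2 * inter // 128, hidden // 128), dtype=torch.float32, device="cuda"
)
w2_scale_demo = torch.empty(
    (E, hidden // 128, inter // 128), dtype=torch.float32, device="cuda"
)
topk_weights_demo = torch.rand((M, topk), dtype=torch.float32, device="cuda")
topk_ids_demo = torch.randint(0, E, (M, topk), dtype=torch.int32, device="cuda")
out = run_cutlass_block_scaled_fused_experts(
    a_demo,
    w1_demo,
    w2_demo,
    w1_scale_demo,
    w2_scale_demo,
    topk_weights_demo,
    topk_ids_demo,
)  # (M, hidden) tensor in a_demo.dtype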
# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,