[MoE Refactor][3/N] Deprecate cutlass block quant fp8 (b200) (#30990)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, _resize_cache
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -896,162 +896,6 @@ def cutlass_moe_fp4(
|
||||
)
|
||||
|
||||
|
||||
def _valid_cutlass_block_scaled_grouped_gemm(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
inplace: bool,
|
||||
activation: str,
|
||||
apply_router_weight_on_input: bool,
|
||||
expert_map: torch.Tensor | None,
|
||||
) -> bool:
|
||||
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
|
||||
return N % 128 == 0 and K % 128 == 0
|
||||
|
||||
_, K, N = w2.size()
|
||||
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled: unaligned problem size. "
|
||||
"N: %s, K: %s",
|
||||
N,
|
||||
K,
|
||||
)
|
||||
return False
|
||||
|
||||
if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn:
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). "
|
||||
"w1.dtype: %s, w2.dtype: %s",
|
||||
w1.dtype,
|
||||
w2.dtype,
|
||||
)
|
||||
return False
|
||||
|
||||
if expert_map is not None:
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported."
|
||||
)
|
||||
return False
|
||||
|
||||
if activation != "silu":
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled: only activation silu is supported."
|
||||
)
|
||||
return False
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled:"
|
||||
" apply_router_weight_on_input is not supported."
|
||||
)
|
||||
return False
|
||||
|
||||
if inplace:
|
||||
logger.debug_once(
|
||||
"CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
def run_cutlass_block_scaled_fused_experts(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Run a fused-MoE forward pass with 128x128 block-scaled fp8 CUTLASS
    grouped GEMMs.

    Pipeline: quantize activations to fp8 -> grouped GEMM with w1 ->
    silu_and_mul -> re-quantize -> grouped GEMM with w2 -> un-permute and
    reduce over the top-k experts with the router weights.

    Args:
        a: (m, k) activations; dtype must be half or bfloat16 (asserted).
        w1: fp8 e4m3 expert weights; after the transpose below w1_q is
            (num_experts, k, 2 * n) per the shape asserts.
        w2: fp8 e4m3 expert weights; after the transpose below w2_q has
            n = w2_q.size(1) rows per expert.
        w1_scale: per-block scales for w1, transposed to match w1_q.
        w2_scale: per-block scales for w2, transposed to match w2_q.
        topk_weights: (m, topk) router weights, same shape as topk_ids.
        topk_ids: (m, topk) selected expert indices per token.

    Returns:
        (m, k) tensor of dtype ``a.dtype``.
    """
    # Swap the last two weight dims — presumably the layout the grouped-gemm
    # kernels expect; NOTE(review): confirm against the kernel signature.
    w1_q = w1.transpose(1, 2)
    w2_q = w2.transpose(1, 2)
    w1_scale = w1_scale.transpose(1, 2)
    w2_scale = w2_scale.transpose(1, 2)

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert a.shape[0] == topk_ids.shape[0], (
        "a and topk_ids must have the same batch size"
    )
    assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn"
    assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn"
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    out_dtype = a.dtype
    num_experts = w1_q.size(0)
    m = a.size(0)  # number of tokens
    k = w1_q.size(1)  # hidden size (== a.shape[1] per assert above)
    n = w2_q.size(1)  # intermediate size; w1 produces 2*n for the gated act

    topk = topk_ids.size(1)

    # Quantize activations to fp8 with 128x128 block scales.
    a_q, a1_scale = _fp8_quantize(
        a, A_scale=None, per_act_token=False, block_shape=[128, 128]
    )
    device = a_q.device

    # Per-expert token offsets plus (M, N, K) problem sizes for each of the
    # two grouped gemms, filled in by get_cutlass_moe_mm_data below.
    expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
    problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
    problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)

    # Permutation maps over the m*topk expanded tokens: a_map gathers inputs
    # into expert-grouped order, c_map scatters results back (see the final
    # c2[c_map] below).
    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    ops.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    # Gather the quantized activations into expert order. The round-trip
    # through a uint8 view presumably works around advanced indexing not
    # being supported on the fp8 dtype — NOTE(review): confirm.
    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map]

    # Output buffers: c1 holds the 2n-wide gated pre-activation, c2 the final
    # k-wide per-(token, expert) result.
    c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device)
    c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device)

    # First grouped GEMM: expanded activations x w1 -> c1.
    ops.cutlass_blockwise_scaled_grouped_mm(
        c1,
        rep_a_q,
        w1_q,
        rep_a1_scales,
        w1_scale,
        problem_sizes1,
        expert_offsets[:-1],
    )

    # Gated activation: collapses the 2n-wide c1 into the n-wide intermediate.
    intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device)
    torch.ops._C.silu_and_mul(intermediate, c1)

    # Re-quantize the intermediate activations for the second gemm.
    intermediate_q, a2_scale = _fp8_quantize(
        intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128]
    )

    # Second grouped GEMM: intermediate x w2 -> c2.
    ops.cutlass_blockwise_scaled_grouped_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale,
        w2_scale,
        problem_sizes2,
        expert_offsets[:-1],
    )

    # Un-permute back to token order, weight each expert's contribution by the
    # router weight, and sum over the topk dimension -> (m, k).
    return (
        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
# W4A8
|
||||
def run_cutlass_moe_w4a8_fp8(
|
||||
output: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user