Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
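For context: ruff replaces both of the previous tools in one pass, with `ruff format` covering code formatting (previously yapf) and its import-sorting rules (`--select I`) covering isort. A minimal sketch of the equivalent invocations follows; the project's actual configuration and hook setup are not shown in this excerpt, so the paths and flags below are illustrative assumptions only.

# Illustrative sketch only: running ruff's formatter and import sorter from Python.
# The project's real entry points (pyproject.toml / pre-commit hooks) are not part
# of this diff excerpt, so the target path "." and flags here are assumptions.
import subprocess

# "ruff format" replaces yapf as the code formatter.
subprocess.run(["ruff", "format", "."], check=True)

# "ruff check --select I --fix" applies import-sorting fixes, replacing isort.
subprocess.run(["ruff", "check", "--select", "I", "--fix", "."], check=True)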


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
"""CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional
import torch
@@ -10,13 +11,17 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
moe_permute,
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, _resize_cache
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -56,20 +61,28 @@ def run_cutlass_moe_fp8(
assert w2.dtype == torch.float8_e4m3fn
assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1"
assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.size(
1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.size(
1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch"
assert (
w1_scale.dim() == 1 or w1_scale.size(1) == 1 or w1_scale.shape[1] == w1.size(1)
), "W1 scale shape mismatch"
assert (
w2_scale.dim() == 1 or w2_scale.size(1) == 1 or w2_scale.shape[1] == w2.size(1)
), "W2 scale shape mismatch"
assert w1.size(0) == w2.size(0), "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size(
0) == 1 or a1q_scale.size(
0) == a1q.shape[0], "Input scale shape mismatch"
assert (
a1q_scale is None
or a1q_scale.dim() == 0
or a1q_scale.size(0) == 1
or a1q_scale.size(0) == a1q.shape[0]
), "Input scale shape mismatch"
assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch"
assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch"
assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size(
0) == 1 or a2_scale.size(
0) == a1q.shape[0], "Intermediate scale shape mismatch"
assert (
a2_scale is None
or a2_scale.dim() == 0
or a2_scale.size(0) == 1
or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
assert expert_num_tokens is None
@@ -97,8 +110,9 @@ def run_cutlass_moe_fp8(
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids
@@ -108,35 +122,39 @@ def run_cutlass_moe_fp8(
if use_batched_format:
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(local_E * padded_M, N))
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (local_E * padded_M, N)
)
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
else:
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M * topk, K))
a1q_perm = _resize_cache(
workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K)
)
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
# original workspace are based on input hidden_states dtype (bf16)
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M * topk, N))
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N)
)
mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format:
assert expert_num_tokens is not None
expert_offsets = torch.empty((local_E),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
expert_offsets = torch.empty((local_E), dtype=torch.int32, device=device)
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
ops.get_cutlass_pplx_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
expert_num_tokens,
local_E,
padded_M,
N,
K,
)
w1_scale = w1_scale.reshape(w1_scale.size(0), -1)
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
@@ -146,15 +164,14 @@ def run_cutlass_moe_fp8(
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
num_expert = global_num_experts if expert_map is None \
else expert_map.size(0)
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
a1q,
@@ -163,12 +180,13 @@ def run_cutlass_moe_fp8(
num_expert,
local_E,
expert_map,
permuted_hidden_states=a1q_perm)
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_offsets[:-1]
ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
problem_sizes2,
global_num_experts, N, K)
ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
)
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
@@ -176,38 +194,59 @@ def run_cutlass_moe_fp8(
# this rank handles only partial tokens, or when it is batched .
mm1_out.fill_(0)
ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
ops.cutlass_moe_mm(
mm1_out,
a1q,
w1,
a1q_scale,
w1_scale,
expert_offsets,
problem_sizes1,
ab_strides1,
ab_strides1,
c_strides1,
per_act_token,
per_out_ch,
)
activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
act_out,
a2_scale,
use_per_token_if_dynamic=per_act_token,
output=quant_out)
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
)
if expert_map is not None:
mm2_out.fill_(0)
ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
ops.cutlass_moe_mm(
mm2_out,
a2q,
w2,
a2q_scale,
w2_scale,
expert_offsets,
problem_sizes2,
ab_strides2,
ab_strides2,
c_strides2,
per_act_token,
per_out_ch,
)
if use_batched_format:
output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else:
# for non-chunking mode the output is resized from workspace13
# so we need to make sure mm2_out uses workspace2.
moe_unpermute(out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm)
moe_unpermute(
out=output,
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: Optional[torch.dtype],
@@ -256,23 +295,40 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = self.activation_formats[
0] == mk.FusedMoEActivationFormat.BatchedExperts
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
)
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides1, self.c_strides2, workspace13, workspace2,
output,
hidden_states,
w1,
w2,
topk_ids,
activation_callable,
global_num_experts,
expert_map,
self.w1_scale,
self.w2_scale,
a1q_scale,
a2_scale,
self.ab_strides1,
self.ab_strides2,
self.c_strides1,
self.c_strides2,
workspace13,
workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format, topk_weights)
self.per_act_token_quant,
self.per_out_ch_quant,
use_batched_format,
topk_weights,
)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
@@ -293,10 +349,12 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
@property
def activation_formats(
self
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool:
return True
@@ -323,12 +381,15 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
max_experts_per_worker: int,
@@ -354,10 +415,12 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
@property
def activation_formats(
self
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_chunking(self) -> bool:
return False
@@ -381,13 +444,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
max(N // 2, K))
workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
def cutlass_moe_fp8(
@@ -456,18 +521,15 @@ def cutlass_moe_fp8(
assert quant_config is not None
if quant_config.a1_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a1_scale.numel() != 1)
assert quant_config.per_act_token_quant == quant_config.a1_scale.numel() != 1
if quant_config.a2_scale is not None:
assert (quant_config.per_act_token_quant ==
quant_config.a2_scale.numel() != 1)
assert quant_config.per_act_token_quant == quant_config.a2_scale.numel() != 1
assert (quant_config.w1_scale is None
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
== w1_q.size(1))))
assert quant_config.w1_scale is None or (
quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) == w1_q.size(1))
)
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0)
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
@@ -550,25 +612,30 @@ def run_cutlass_moe_fp4(
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3
and w2_blockscale.ndim
== 3), ("All Weights must be of rank 3 for cutlass_moe_fp4")
assert (
w1_fp4.ndim == 3
and w2_fp4.ndim == 3
and w1_blockscale.ndim == 3
and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2
and e_w1 == e), ("Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}")
assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
"expected `n`")
assert (m == m_a), "input shape mismatch"
assert e_w1 == e_w2 and e_w1 == e, (
"Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}",
)
assert k_a == half_k_w1 * 2 and k == k_w2, (
"Hidden size mismatch between a, w1 and w2"
)
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.size(0) == m and topk_ids.size(0)
== m), ("topk must be provided for each row of a")
assert topk_weights.size(0) == m and topk_ids.size(0) == m, (
"topk must be provided for each row of a"
)
topk = topk_ids.size(1)
out_dtype = a.dtype
num_topk = topk_ids.size(1)
@@ -585,15 +652,25 @@ def run_cutlass_moe_fp4(
if apply_router_weight_on_input:
# TODO: this only works for topK=1, will need to update for topK>1
assert num_topk == 1, \
assert num_topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a.mul_(topk_weights.to(out_dtype))
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k,
blockscale_offsets)
ops.get_cutlass_moe_mm_data(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
e,
n,
k,
blockscale_offsets,
)
a = ops.shuffle_rows(a, a_map)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
@@ -606,17 +683,34 @@ def run_cutlass_moe_fp4(
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1])
ops.cutlass_fp4_moe_mm(
c1,
rep_a_fp4,
w1_fp4,
rep_a_blockscale,
w1_blockscale,
w1_alphas,
problem_sizes1,
expert_offsets[:-1],
blockscale_offsets[:-1],
)
del rep_a_fp4, rep_a_blockscale
torch.ops._C.silu_and_mul(c2, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1])
ops.cutlass_fp4_moe_mm(
c3,
int_fp4,
w2_fp4,
int_blockscale,
w2_blockscale,
w2_alphas,
problem_sizes2,
expert_offsets[:-1],
blockscale_offsets[:-1],
)
del int_fp4, int_blockscale
c3 = ops.shuffle_rows(c3, c_map)
@@ -624,9 +718,12 @@ def run_cutlass_moe_fp4(
assert output.dtype == out_dtype
if not apply_router_weight_on_input:
output.copy_(
(c3.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
non_blocking=True)
(
c3.view(m, num_topk, k)
* topk_weights.view(m, num_topk, 1).to(out_dtype)
).sum(dim=1),
non_blocking=True,
)
else:
output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
return
@@ -634,7 +731,6 @@ def run_cutlass_moe_fp4(
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
@@ -649,14 +745,18 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@property
def activation_formats(
self
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_expert_map(self) -> bool:
return False
@@ -691,8 +791,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
return (
workspace1,
workspace2,
output,
self.out_dtype if self.out_dtype is not None else a.dtype,
)
def apply(
self,
@@ -740,21 +844,24 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
e: int,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
e: int,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
)
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
@@ -799,10 +906,13 @@ def cutlass_moe_fp4(
def _valid_cutlass_block_scaled_grouped_gemm(
w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor]) -> bool:
w1: torch.Tensor,
w2: torch.Tensor,
inplace: bool,
activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor],
) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0
@@ -816,7 +926,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
)
return False
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). "
"w1.dtype: %s, w2.dtype: %s",
@@ -827,19 +937,21 @@ def _valid_cutlass_block_scaled_grouped_gemm(
if expert_map is not None:
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
" not supported.")
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported."
)
return False
if activation != "silu":
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled: only activation silu is"
" supported.")
"CutlassBlockScaledGroupedGemm disabled: only activation silu is supported."
)
return False
if apply_router_weight_on_input:
logger.debug_once("CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported.")
logger.debug_once(
"CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported."
)
return False
if inplace:
@@ -867,17 +979,16 @@ def run_cutlass_block_scaled_fused_experts(
w2_scale = w2_scale.transpose(1, 2)
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert a.shape[0] == topk_ids.shape[
0], "a and topk_ids must have the same batch size"
assert a.shape[0] == topk_ids.shape[0], (
"a and topk_ids must have the same batch size"
)
assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn"
assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn"
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1_scale expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2_scale expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
out_dtype = a.dtype
@@ -888,21 +999,14 @@ def run_cutlass_block_scaled_fused_experts(
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a,
A_scale=None,
per_act_token=False,
block_shape=[128, 128])
a_q, a1_scale = _fp8_quantize(
a, A_scale=None, per_act_token=False, block_shape=[128, 128]
)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -938,10 +1042,9 @@ def run_cutlass_block_scaled_fused_experts(
intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device)
torch.ops._C.silu_and_mul(intermediate, c1)
intermediate_q, a2_scale = _fp8_quantize(intermediate,
A_scale=None,
per_act_token=False,
block_shape=[128, 128])
intermediate_q, a2_scale = _fp8_quantize(
intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128]
)
ops.cutlass_blockwise_scaled_grouped_mm(
c2,
@@ -953,5 +1056,6 @@ def run_cutlass_block_scaled_fused_experts(
expert_offsets[:-1],
)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)