[fix]: disable cutlass block scaled group gemm for EP (#20781)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
This commit is contained in:
Duncan Moss
2025-07-10 19:39:18 -07:00
committed by GitHub
parent 0cf893cae1
commit 5923ab9524
3 changed files with 34 additions and 9 deletions

View File

@@ -553,8 +553,10 @@ def cutlass_moe_fp4(a: torch.Tensor,
return out.to(dtype=out_dtype)
def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm(
w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor]) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0
@@ -570,6 +572,29 @@ def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
return False
if expert_map is not None:
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
" not supported.")
return False
if activation != "silu":
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: only activation silu is"
" supported.")
return False
if apply_router_weight_on_input:
logger.debug("CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported.")
return False
if inplace:
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
)
return False
return True