[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (#26935)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Commit b3dda72c23 (parent fb0571b077)
Author: Wentao Ye
Date: 2025-10-16 16:46:48 -04:00 (committed by GitHub)
8 changed files with 57 additions and 46 deletions
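In short: the vLLM-side wrapper deep_gemm_block_shape() is dropped and call sites query the alignment directly from vllm.utils.deep_gemm. As the hunks below show, the new getter returns the [m, k] alignment pair, so the existing [0] indexing and block-shape comparisons carry over unchanged. A minimal before/after sketch (example values are illustrative):

    # Before: vLLM wrapper around DeepGEMM's old single-axis getter
    #   from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
    #       deep_gemm_block_shape,
    #   )
    #   block_m = deep_gemm_block_shape()[0]

    # After: ask DeepGEMM directly for the [m, k] alignment pair
    from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

    mk_alignment = get_mk_alignment_for_contiguous_layout()  # e.g. [128, 128]
    block_m = mk_alignment[0]  # M-dim alignment for the contiguous layout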


@@ -12,10 +12,7 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
-    compute_aligned_M,
-    deep_gemm_block_shape,
-)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
@@ -23,7 +20,11 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
 )
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)


 def _generate_optimal_warmup_m_values(
@@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     """
     Return True if the input module/layer could be processed with DeepGEMM.
     """
-    block_size = deep_gemm_block_shape()[0]
+    block_size = get_mk_alignment_for_contiguous_layout()[0]
     if not (
         isinstance(module, LinearBase)
         and isinstance(module.quant_method, Fp8LinearMethod)
@@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     w, _, block_sizes = _extract_data_from_linear_base_module(module)
     return (
-        block_sizes == deep_gemm_block_shape()
+        block_sizes == get_mk_alignment_for_contiguous_layout()
         and w.ndim == 2
         and w.shape[0] % block_size == 0
         and w.shape[1] % block_size == 0
@@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     if (
         moe_quant_config is None
         or moe_quant_config.quant_dtype != torch.float8_e4m3fn
-        or moe_quant_config.block_shape != deep_gemm_block_shape()
+        or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout()
     ):
         return False
@@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
         return
     n, k = w.size()
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     device = w.device
     a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
@@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device
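A usage note on the warmup hunks above: they derive block_m so that candidate M values respect DeepGEMM's alignment for the contiguous grouped-GEMM layout. A hypothetical helper sketching that rounding (round_up_to_block_m is illustrative, not part of this change):

    from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

    def round_up_to_block_m(num_tokens: int) -> int:
        # Pad M up to the next multiple of the M-dim alignment that
        # DeepGEMM requires for its contiguous grouped-GEMM layout.
        block_m = get_mk_alignment_for_contiguous_layout()[0]
        return ((num_tokens + block_m - 1) // block_m) * block_m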