[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (#26935)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -12,10 +12,7 @@ from tqdm import tqdm
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
compute_aligned_M,
|
||||
deep_gemm_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
@@ -23,7 +20,11 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
)
|
||||
|
||||
|
||||
def _generate_optimal_warmup_m_values(
|
||||
@@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
"""
|
||||
Return True if the input module/layer could be processed with DeepGEMM.
|
||||
"""
|
||||
block_size = deep_gemm_block_shape()[0]
|
||||
block_size = get_mk_alignment_for_contiguous_layout()[0]
|
||||
if not (
|
||||
isinstance(module, LinearBase)
|
||||
and isinstance(module.quant_method, Fp8LinearMethod)
|
||||
@@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
|
||||
w, _, block_sizes = _extract_data_from_linear_base_module(module)
|
||||
return (
|
||||
block_sizes == deep_gemm_block_shape()
|
||||
block_sizes == get_mk_alignment_for_contiguous_layout()
|
||||
and w.ndim == 2
|
||||
and w.shape[0] % block_size == 0
|
||||
and w.shape[1] % block_size == 0
|
||||
@@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
if (
|
||||
moe_quant_config is None
|
||||
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
|
||||
or moe_quant_config.block_shape != deep_gemm_block_shape()
|
||||
or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout()
|
||||
):
|
||||
return False
|
||||
|
||||
@@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
|
||||
return
|
||||
|
||||
n, k = w.size()
|
||||
block_m = deep_gemm_block_shape()[0]
|
||||
block_m = get_mk_alignment_for_contiguous_layout()[0]
|
||||
|
||||
device = w.device
|
||||
a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
|
||||
@@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
|
||||
|
||||
assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
|
||||
|
||||
block_m = deep_gemm_block_shape()[0]
|
||||
block_m = get_mk_alignment_for_contiguous_layout()[0]
|
||||
num_experts = w1.size(0)
|
||||
device = w1.device
|
||||
|
||||
|
||||
Reference in New Issue
Block a user