[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (#26935)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
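In short: call sites move from the repo-local deep_gemm_block_shape() helper, which wrapped DeepGEMM's get_m_alignment_for_contiguous_layout() and duplicated its single m-alignment into [block, block], to the vllm-level get_mk_alignment_for_contiguous_layout(), which returns the [m, k] alignment pair directly. A minimal before/after sketch; the 128 values are typical DeepGEMM fp8 block sizes and are illustrative, not taken from this diff:

    # Before: per-module helper built on DeepGEMM's single m-alignment.
    #   import deep_gemm as dg
    #   block = dg.get_m_alignment_for_contiguous_layout()   # a single int, e.g. 128
    #   block_shape = [block, block]                          # [128, 128]

    # After: one vllm wrapper returning the [m, k] pair directly.
    from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

    block_m, block_k = get_mk_alignment_for_contiguous_layout()  # e.g. 128 and 128
    align_m = get_mk_alignment_for_contiguous_layout()[0]        # m-alignment used in shape checks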
@@ -6,14 +6,17 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
 )
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    fp8_m_grouped_gemm_nt_masked,
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)

 logger = init_logger(__name__)

@@ -227,7 +230,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             quant_config: Quantization configuration
         """
         super().__init__(quant_config)
-        assert self.block_shape == deep_gemm_block_shape()
+        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers

@@ -8,8 +8,8 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -31,7 +31,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.batched_deep_gemm_experts = (
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M,
-    deep_gemm_block_shape,
     deepgemm_moe_permute,
     deepgemm_unpermute_and_reduce,
 )
@@ -28,14 +27,17 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 from vllm.utils.functools import run_once

 logger = init_logger(__name__)


 def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
     return align <= M and N % align == 0 and K % align == 0


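The shape gate in _valid_deep_gemm_shape reads directly off the m-alignment: there must be at least one alignment block of tokens, and both weight dimensions must be divisible by it. A small worked example, assuming the alignment is [128, 128] (illustrative, not stated in this diff):

    align = 128  # assumed value of get_mk_alignment_for_contiguous_layout()[0]

    def valid_shape(M: int, N: int, K: int) -> bool:
        # Same predicate as _valid_deep_gemm_shape above, inlined for the example.
        return align <= M and N % align == 0 and K % align == 0

    print(valid_shape(256, 4096, 7168))  # True: enough tokens, N and K divisible by 128
    print(valid_shape(64, 4096, 7168))   # False: fewer tokens than one alignment block
    print(valid_shape(256, 4100, 7168))  # False: 4100 is not a multiple of 128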
@@ -54,7 +56,7 @@ def _valid_deep_gemm(
     M = hidden_states.size(0)
     _, K, N = w2.size()

-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]

     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
@@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device

@@ -173,7 +175,7 @@ def warmup_deepgemm_gg_contiguous_kernels(
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        assert quant_config.block_shape == deep_gemm_block_shape()
+        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
         assert quant_config.quant_dtype == torch.float8_e4m3fn
         assert not quant_config.per_act_token_quant
         assert not quant_config.per_out_ch_quant
@@ -255,7 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             M=topk_ids.size(0),
             num_topk=topk_ids.size(1),
             local_num_experts=local_num_experts,
-            alignment=deep_gemm_block_shape()[0],
+            alignment=get_mk_alignment_for_contiguous_layout()[0],
             expert_tokens_meta=expert_tokens_meta,
         )

@@ -364,7 +366,7 @@ def deep_gemm_moe_fp8(
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
-       block_shape=deep_gemm_block_shape(),
+       block_shape=get_mk_alignment_for_contiguous_layout(),
     )

     fn = mk.FusedMoEModularKernel(
@@ -5,23 +5,13 @@ Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00b
 and updated to fit vllm needs and terminology.
 """

-import functools
-
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
 from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
-
-
-@functools.cache
-def deep_gemm_block_shape() -> list[int]:
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    block = dg.get_m_alignment_for_contiguous_layout()
-    return [block, block]
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 def expert_num_tokens_round_up_and_sum(
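With every consumer importing the alignment from vllm.utils.deep_gemm, the cached helper deleted above, together with its functools and direct deep_gemm imports, has no remaining callers. The migrated call sites in this diff only rely on the wrapper providing a stable two-element [m, k] list; a hedged sketch of that expectation (the wrapper's implementation is not part of this diff):

    # Properties the migrated call sites depend on, written as checks, not as the wrapper's code.
    from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

    align = get_mk_alignment_for_contiguous_layout()
    assert isinstance(align, list) and len(align) == 2        # an [m, k] pair, e.g. [128, 128]
    assert align == get_mk_alignment_for_contiguous_layout()  # stable across calls, like the old cached helper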
@@ -354,8 +344,7 @@ def deepgemm_moe_permute(
     H = aq.size(1)
     device = aq.device

-    block_m = deep_gemm_block_shape()[0]
-    block_k = deep_gemm_block_shape()[1]
+    block_m, block_k = get_mk_alignment_for_contiguous_layout()

     M_sum = compute_aligned_M(
         M=topk_ids.size(0),
@@ -10,9 +10,11 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm,
     _valid_deep_gemm_shape,
 )
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)


 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -28,7 +30,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.deep_gemm_expert = (
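Both wrapper classes touched above (BatchedTritonOrDeepGemmExperts and TritonOrDeepGemmExperts) apply the same gate: DeepGEMM experts are only selected when fp8 W8A8 quantization is active and the configured block shape equals the MK alignment; otherwise the Triton path is used. A condensed, illustrative sketch of that gate (should_use_deep_gemm is a hypothetical name, not a function in this diff):

    from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

    def should_use_deep_gemm(
        allow_deep_gemm: bool, use_fp8_w8a8: bool, block_shape: list[int] | None
    ) -> bool:
        # Condensed form of the self.allow_deep_gemm expression in both classes.
        return (
            allow_deep_gemm
            and use_fp8_w8a8
            and block_shape == get_mk_alignment_for_contiguous_layout()
        )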