[Feature] Integrate SM100 DeepGEMM support (#20087)
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
-dg_available = False
-try:
-    import deep_gemm
-    dg_available = True
-except ImportError:
-    pass
+dg_available = has_deep_gemm()
+
+if dg_available:
+    from deep_gemm import get_m_alignment_for_contiguous_layout
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
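Note: the module no longer probes DeepGEMM with a bare try/except import; availability is decided once by vllm.utils.has_deep_gemm(), and DeepGEMM symbols are imported only when that check passes. A minimal sketch of what such a probe amounts to, for orientation only; the function name below is a stand-in, and the real helper in vllm.utils may cache its result or apply extra checks:

    # Illustrative stand-in for a has_deep_gemm()-style availability probe.
    # Not vLLM's implementation; it only shows the import-availability idea.
    import importlib.util

    def deep_gemm_is_importable() -> bool:
        # True when the deep_gemm package can be imported in this environment.
        return importlib.util.find_spec("deep_gemm") is not None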
@@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):
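Note: the new skipif excludes this test when DeepGEMM runs its SM100 (Blackwell) kernels, which expect power-of-two (UE8M0) scale factors rather than the fine-grained FP32 scales this reference path produces. A rough sketch of the power-of-two constraint those scales imply; this is illustration only, not vLLM or DeepGEMM code:

    # Illustrative only: an E8M0-style scale has an 8-bit exponent and no
    # mantissa, so every representable value is an exact power of two.
    import torch

    def round_scales_up_to_power_of_two(scales: torch.Tensor) -> torch.Tensor:
        # Round each positive scale up to the nearest power of two.
        return torch.exp2(torch.ceil(torch.log2(scales)))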
@@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     torch.manual_seed(seed)
 
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
 
-    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
-
+    block_m = get_m_alignment_for_contiguous_layout()
     block_size = [block_m, block_m]
     dtype = torch.bfloat16
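Note: block_m is now taken from the get_m_alignment_for_contiguous_layout symbol imported at module level in the first hunk, so the test body no longer references the deep_gemm module directly. For orientation, a small sketch of what such an M-alignment value is typically used for; the helper name is DeepGEMM's, but the padding function below is purely illustrative:

    # Illustrative only: pad a token count up to DeepGEMM's M alignment so a
    # contiguous grouped layout can be formed from it.
    def pad_m_to_alignment(m: int, block_m: int) -> int:
        # Smallest multiple of block_m that is >= m.
        return ((m + block_m - 1) // block_m) * block_m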