Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,15 +5,21 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
_valid_deep_gemm_shape,
|
||||
deep_gemm_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
@@ -24,8 +30,7 @@ if dg_available:
|
||||
from deep_gemm import get_m_alignment_for_contiguous_layout
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
@@ -97,8 +102,7 @@ TOP_KS = [1, 2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
|
||||
block_shape):
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
@@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
@@ -149,8 +147,9 @@ def setup_cuda():
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
monkeypatch):
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
@@ -188,12 +187,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
@@ -210,8 +206,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
@@ -245,36 +240,38 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
|
||||
and current_platform.is_cuda_alike())
|
||||
use_cudagraph = (
|
||||
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids, block_size)
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
|
||||
)
|
||||
|
||||
if use_compile:
|
||||
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
|
||||
backend="inductor",
|
||||
fullgraph=True)
|
||||
deep_gemm_moe_fp8_fn = torch.compile(
|
||||
deep_gemm_moe_fp8, backend="inductor", fullgraph=True
|
||||
)
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||
else:
|
||||
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
if use_cudagraph:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
out = deep_gemm_moe_fp8_fn(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
Reference in New Issue
Block a user