Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -11,14 +11,18 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
per_block_cast_to_fp8)
per_token_group_quant_fp8,
)
from vllm.utils.deep_gemm import (
calc_diff,
is_deep_gemm_supported,
per_block_cast_to_fp8,
)
BLOCK_SIZE = [128, 128]
@@ -37,8 +41,10 @@ def make_block_quant_fp8_weights(
w2 shape: (E, K, N)
"""
dtype = torch.bfloat16
fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
torch.float8_e4m3fn).min
fp8_max, fp8_min = (
torch.finfo(torch.float8_e4m3fn).max,
torch.finfo(torch.float8_e4m3fn).min,
)
# bf16 reference weights
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
@@ -54,24 +60,16 @@ def make_block_quant_fp8_weights(
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty(e,
n_tiles_w1,
k_tiles_w1,
device="cuda",
dtype=torch.float32)
w2_s = torch.empty(e,
n_tiles_w2,
k_tiles_w2,
device="cuda",
dtype=torch.float32)
w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32)
w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=block_size,
use_ue8m0=True)
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=block_size,
use_ue8m0=True)
w1[i], w1_s[i] = per_block_cast_to_fp8(
w1_bf16[i], block_size=block_size, use_ue8m0=True
)
w2[i], w2_s[i] = per_block_cast_to_fp8(
w2_bf16[i], block_size=block_size, use_ue8m0=True
)
return w1, w2, w1_s, w2_s
@@ -81,18 +79,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
Triton baseline within tolerance.
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
tokens_bf16 = (
torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
.clamp_min_(-1)
.clamp_max_(1)
)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
block_size)
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)
router_logits = torch.randn(m,
num_experts,
device="cuda",
dtype=torch.float32)
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
@@ -147,15 +144,14 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe")
"vllm.model_executor.layers.fused_moe.fused_moe"
)
call_counter = {"cnt": 0}
@@ -165,8 +161,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
@@ -181,6 +176,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
)
# ensure that the DeepGEMM path was indeed taken.
assert call_counter["cnt"] == 1, \
f"DeepGEMM path was not executed during the test. " \
assert call_counter["cnt"] == 1, (
f"DeepGEMM path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)