Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,16 +5,17 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
@@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
||||
act_out, block_k)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
@@ -131,15 +126,19 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
|
||||
quant_config.w2_scale, score, topk,
|
||||
block_size)
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
score,
|
||||
topk,
|
||||
block_size,
|
||||
)
|
||||
|
||||
# Check results
|
||||
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
||||
|
||||
Reference in New Issue
Block a user