Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -4,6 +4,7 @@
 
 Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
 """
+
 import importlib
 from typing import Optional
 
@@ -15,17 +16,19 @@ from vllm.platforms import current_platform
 device = "cuda"
 
 triton_scaled_mm_module = importlib.import_module(
-    "vllm.model_executor.layers.quantization.compressed_tensors."
-    "triton_scaled_mm")
+    "vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm"
+)
 triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
 
 
-def torch_scaled_mm(a: torch.Tensor,
-                    b: torch.Tensor,
-                    scale_a: torch.Tensor,
-                    scale_b: torch.Tensor,
-                    out_dtype: type[torch.dtype],
-                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def torch_scaled_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     out = torch.mm(a.to(torch.float32), b.to(torch.float32))
     out = scale_a * out
     out = scale_b.T * out
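
A note on the reference implementation in this hunk: it relies on broadcasting, with `scale_a` scaling rows of the float32 product and `scale_b.T` scaling its columns. A minimal, standalone sketch of that broadcasting (shapes here are assumptions inferred from the `.T`, not values taken from this file):

import torch

M, K, N = 4, 8, 3
a = torch.randn(M, K)
b = torch.randn(K, N)
scale_a = torch.rand(M, 1)  # one scale per row of `a` (a 1x1 tensor when scalar)
scale_b = torch.rand(N, 1)  # one scale per column of `b`, stored as a column vector

out = torch.mm(a, b)   # (M, N)
out = scale_a * out    # (M, 1) broadcasts across columns
out = scale_b.T * out  # (1, N) broadcasts across rows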
@@ -44,20 +47,22 @@ def get_8bit_types():
 
 
 # This test is to check regressions for int8 support on ROCm.
-@pytest.mark.parametrize("model_path", [
-    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
-])
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    ],
+)
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.skipif(not current_platform.is_rocm(),
-                    reason="Should only run on ROCm")
-def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
-                                      max_tokens, num_logprobs):
+@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm")
+def test_rocm_compressed_tensors_w8a8(
+    vllm_runner, example_prompts, model_path, max_tokens, num_logprobs
+):
     dtype = "bfloat16"
 
     with vllm_runner(model_path, dtype=dtype) as vllm_model:
-        vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
-                                            num_logprobs)
+        vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs)
 
 
 MNK_FACTORS = [
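
The expansion of the first `@pytest.mark.parametrize` above is ruff's Black-compatible "magic trailing comma" at work: a trailing comma inside brackets keeps that collection exploded one element per line, which in turn forces the enclosing call onto multiple lines. A small illustration (example code, not from the repository), assuming default `ruff format` settings:

# With a trailing comma, ruff format keeps the list exploded:
models = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
]

# Without one, the list collapses onto a single line when it fits:
models = ["neuralmagic/Llama-3.2-1B-quantized.w8a8"]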
@@ -76,10 +81,10 @@ MNK_FACTORS = [
 @pytest.mark.parametrize("use_scalar_scale_a", [True, False])
 @pytest.mark.parametrize("use_scalar_scale_b", [True, False])
 @pytest.mark.parametrize("use_bias", [True, False])
-def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
-                   use_scalar_scale_b, use_bias):
-    is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
-                                                    ).is_floating_point()
+def test_scaled_mm(
+    M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
+):
+    is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()
 
     current_platform.seed_everything(0)
 
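
The construction of `scale_a` and `scale_b` sits in context lines this hunk does not show. Purely as an illustration of what the `use_scalar_scale_*` switches typically mean for a scaled-matmul test (the shapes below are assumptions inferred from `torch_scaled_mm`, not code read from this file):

import torch

M, N = 8, 16
use_scalar_scale_a, use_scalar_scale_b = True, False

# Hypothetical sketch: a scalar scale is a single value broadcast everywhere,
# otherwise one scale per output row (for a) / per output column (for b).
scale_a = torch.rand(1, 1) if use_scalar_scale_a else torch.rand(M, 1)
scale_b = torch.rand(1, 1) if use_scalar_scale_b else torch.rand(N, 1)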
@@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
     #
     # So, the values here are kept small enough to avoid this situation.
     if is_floating_point_type(in_dtype):
-        a = (0.25 * torch.rand(
-            (M, K), dtype=torch.float32, device=device)).to(in_dtype)
-        b = (0.25 * torch.rand(
-            (K, N), dtype=torch.float32, device=device)).to(in_dtype)
+        a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
+        b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
     else:
         a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
         b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
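
The comment in this hunk is about keeping float8 inputs small enough that the float32 reference and the Triton kernel round the same way. As a rough illustration of how coarse the float8 formats are (which float8 dtypes are available depends on the PyTorch build; this is a sketch, not part of the test):

import torch

# float8_e4m3fn: 3 mantissa bits, max ~= 448; float8_e5m2: 2 mantissa bits, max ~= 57344.
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dt)
    print(dt, "max:", info.max, "eps:", info.eps)

# Round-tripping through float8 shows how much resolution is lost:
x = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
print(x.to(torch.float8_e4m3fn).to(torch.float32))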
@@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
 
     bias = None
     if use_bias:
-        bias = torch.rand((N, ), device=device, dtype=out_dtype)
+        bias = torch.rand((N,), device=device, dtype=out_dtype)
 
     c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
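
The diff ends just after the kernel call, before the part of the test that checks the result. Presumably (not shown in this diff) `c_check` is compared against the pure-PyTorch reference defined earlier; a hedged sketch of that step, reusing the test-local variables, with tolerances chosen as placeholders rather than taken from the file:

# Sketch of the unshown comparison step; rtol/atol are placeholder values.
c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1)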