Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
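The hunks below show the mechanical effect of the switch: yapf aligned continuation lines under the opening parenthesis, while ruff's formatter (which is black-compatible) uses a 4-space hanging indent, a dedented closing parenthesis, and the "magic trailing comma" that keeps a long call exploded one argument per line. A minimal self-contained sketch of the recurring pattern, reusing names from the diff with illustrative values:

import torch

# Illustrative values only; in the diff these come from the test harness.
NUM_BLOCKS, batch_size, max_num_blocks_per_seq = 32768, 4, 8

# yapf (old style): continuation lines aligned under the opening parenthesis.
block_tables = torch.randint(0,
                             NUM_BLOCKS,
                             (batch_size, max_num_blocks_per_seq),
                             dtype=torch.int32)

# ruff format (new style): 4-space hanging indent; once a call is split, a
# trailing comma keeps it one-argument-per-line on future reformats.
block_tables = torch.randint(
    0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)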


@@ -6,15 +6,18 @@ import flashinfer
 import pytest
 import torch
 
-from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
-                                                    FLOAT8_E4M3_MAX,
-                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.quantization.nvfp4_utils import (
+    FLOAT4_E2M1_MAX,
+    FLOAT8_E4M3_MAX,
+    dequantize_nvfp4_to_dtype,
+)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
 if not current_platform.is_device_capability(100):
-    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
-                allow_module_level=True)
+    pytest.skip(
+        "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
+    )
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -64,8 +67,9 @@ NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
-    quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
-                        Optional[torch.dtype]],
+    quant_dtypes: tuple[
+        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+    ],
     batch_size: int,
     max_seq_lens: tuple[int, int],
     num_heads: tuple[int, int],
@@ -106,7 +110,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
+    kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens
@@ -122,10 +126,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
     k_scale = v_scale = kv_scale
 
     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
-    block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
-                                 (batch_size, max_num_blocks_per_seq),
-                                 dtype=torch.int32)
+    block_tables = torch.randint(
+        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
+    )
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
@@ -147,20 +150,23 @@ def test_flashinfer_trtllm_decode_with_baseline(
 
     # Baseline Decode
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True)
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
-                 sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
-                 window_left=window_left,
-                 logits_soft_cap=soft_cap)
+        workspace_buffer, kv_layout, use_tensor_cores=True
+    )
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_qo_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+        "NONE",
+        sm_scale=sm_scale,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        window_left=window_left,
+        logits_soft_cap=soft_cap,
+    )
 
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
@@ -169,17 +175,21 @@ def test_flashinfer_trtllm_decode_with_baseline(
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = (
+            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
+        ).to(torch.float32)
 
     # TRTLLM Decode
     if o_quant_dtype == FP4_DTYPE:
         output_trtllm = flashinfer.utils.FP4Tensor(
-            torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
-                        dtype=torch.uint8),
-            torch.empty((round_up(query.shape[0], 128),
-                         round_up(query.shape[1] * query.shape[2] // 16, 4)),
-                        dtype=torch.float8_e4m3fn),
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+            torch.empty(
+                (
+                    round_up(query.shape[0], 128),
+                    round_up(query.shape[1] * query.shape[2] // 16, 4),
+                ),
+                dtype=torch.float8_e4m3fn,
+            ),
         )
     else:
         output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
@@ -201,13 +211,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
         output_trtllm = output_trtllm.to(dtype) * o_scale
     elif o_quant_dtype == FP4_DTYPE:
         output_trtllm.data = output_trtllm.data.reshape(
-            -1, query.shape[1] * query.shape[2] // 2)
-        output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
-                                                  output_trtllm.scale,
-                                                  o_sf_scale, dtype,
-                                                  query.device)
-        output_trtllm = output_trtllm.reshape(-1, query.shape[1],
-                                              query.shape[2])
+            -1, query.shape[1] * query.shape[2] // 2
+        )
+        output_trtllm = dequantize_nvfp4_to_dtype(
+            output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
+        )
+        output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
         rtol, atol = 3e-1, 1e0
@@ -216,8 +225,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         rtol, atol = 1e-2, 2e-2
 
-    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
-        f"{torch.max(torch.abs(output - output_trtllm))}"
+    (
+        torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
+        f"{torch.max(torch.abs(output - output_trtllm))}",
+    )
 
 
 @pytest.mark.parametrize("dtype", DTYPE)
@@ -233,8 +244,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
-    quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
-                        Optional[torch.dtype]],
+    quant_dtypes: tuple[
+        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+    ],
     batch_size: int,
     max_seq_lens: tuple[int, int],
     num_heads: tuple[int, int],
@@ -270,17 +282,16 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
+    q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
     q_lens[-1] = max_q_len
-    q_indptr = torch.cat([
-        torch.tensor([0], dtype=torch.int32),
-        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
-    ])
+    q_indptr = torch.cat(
+        [
+            torch.tensor([0], dtype=torch.int32),
+            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+        ]
+    )
 
-    query = torch.randn(torch.sum(q_lens).item(),
-                        num_qo_heads,
-                        head_size,
-                        dtype=dtype)
+    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
         query, q_scale = to_float8(query)
         ref_query = query.to(dtype) * q_scale
@@ -288,7 +299,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
+    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
@@ -304,10 +315,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     k_scale = v_scale = kv_scale
 
     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
-    block_tables = torch.randint(0,
-                                 NUM_BLOCKS,
-                                 (batch_size, max_num_blocks_per_seq),
-                                 dtype=torch.int32)
+    block_tables = torch.randint(
+        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
+    )
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
@@ -329,21 +339,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
 
     # Baseline Prefill
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout)
-    wrapper.plan(q_indptr,
-                 kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 causal=True,
-                 sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
-                 window_left=window_left,
-                 logits_soft_cap=soft_cap)
+        workspace_buffer, kv_layout
+    )
+    wrapper.plan(
+        q_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_qo_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+        causal=True,
+        sm_scale=sm_scale,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        window_left=window_left,
+        logits_soft_cap=soft_cap,
+    )
 
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
@@ -352,17 +365,21 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = (
+            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
+        ).to(torch.float32)
 
     # TRTLLM Prefill
     if o_quant_dtype == FP4_DTYPE:
         output_trtllm = flashinfer.utils.FP4Tensor(
-            torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
-                        dtype=torch.uint8),
-            torch.empty((round_up(query.shape[0], 128),
-                         round_up(query.shape[1] * query.shape[2] // 16, 4)),
-                        dtype=torch.float8_e4m3fn),
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+            torch.empty(
+                (
+                    round_up(query.shape[0], 128),
+                    round_up(query.shape[1] * query.shape[2] // 16, 4),
+                ),
+                dtype=torch.float8_e4m3fn,
+            ),
         )
     else:
         output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
@@ -388,13 +405,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         output_trtllm = output_trtllm.to(dtype) * o_scale
     elif o_quant_dtype == FP4_DTYPE:
         output_trtllm.data = output_trtllm.data.reshape(
-            -1, query.shape[1] * query.shape[2] // 2)
-        output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
-                                                  output_trtllm.scale,
-                                                  o_sf_scale, dtype,
-                                                  query.device)
-        output_trtllm = output_trtllm.reshape(-1, query.shape[1],
-                                              query.shape[2])
+            -1, query.shape[1] * query.shape[2] // 2
+        )
+        output_trtllm = dequantize_nvfp4_to_dtype(
+            output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
+        )
+        output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
         rtol, atol = 4e-1, 1e0
@@ -405,5 +421,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     else:
         rtol, atol = 1e-2, 1e-2
 
-    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
-        f"{torch.max(torch.abs(output - output_trtllm))}"
+    (
+        torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
+        f"{torch.max(torch.abs(output - output_trtllm))}",
+    )
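
One behavioral note on the final hunk of each test: torch.testing.assert_close raises AssertionError on mismatch and returns None, so the trailing f-string is an unused element of a discarded tuple in both the old backslash-continued form and the new parenthesized form; the reformat preserves that behavior rather than changing it. A self-contained sketch of the distinction (passing the diff via msg= is a suggestion here, not something this commit does, though msg is a real assert_close parameter):

import torch

output = torch.zeros(4)
output_trtllm = torch.full((4,), 0.5)

try:
    # The pattern from the diff: assert_close raises on mismatch by itself,
    # so the f-string is just an element of a throwaway tuple.
    (
        torch.testing.assert_close(output, output_trtllm, atol=2e-2, rtol=1e-2),
        f"{torch.max(torch.abs(output - output_trtllm))}",
    )
except AssertionError:
    pass  # raised by assert_close; the tuple's f-string never reaches the user

# To surface the max abs diff in the failure text, assert_close accepts msg=:
torch.testing.assert_close(
    output,
    output_trtllm,
    atol=1.0,  # loosened so this demonstration call passes
    rtol=1e-2,
    msg=f"max abs diff: {torch.max(torch.abs(output - output_trtllm))}",
)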