Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,15 +6,18 @@ import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(
|
||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||
)
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -64,8 +67,9 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
@@ -106,7 +110,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens
|
||||
@@ -122,10 +126,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
@@ -147,20 +150,23 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
|
||||
# Baseline Decode
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout, use_tensor_cores=True)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
workspace_buffer, kv_layout, use_tensor_cores=True
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
@@ -169,17 +175,21 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
o_sf_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
# TRTLLM Decode
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
@@ -201,13 +211,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 3e-1, 1e0
|
||||
@@ -216,8 +225,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
else:
|
||||
rtol, atol = 1e-2, 2e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@@ -233,8 +244,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
@@ -270,17 +282,16 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
|
||||
q_lens[-1] = max_q_len
|
||||
q_indptr = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
])
|
||||
q_indptr = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(),
|
||||
num_qo_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
@@ -288,7 +299,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
@@ -304,10 +315,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
@@ -329,21 +339,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
|
||||
# Baseline Prefill
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout)
|
||||
wrapper.plan(q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
workspace_buffer, kv_layout
|
||||
)
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
@@ -352,17 +365,21 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
o_sf_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
# TRTLLM Prefill
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
@@ -388,13 +405,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 4e-1, 1e0
|
||||
@@ -405,5 +421,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user