Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
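For context on the style change: yapf wraps long calls by aligning continuation lines under the opening parenthesis, while ruff's formatter is Black-compatible and instead breaks the call into an indented, parenthesized block with the closing parenthesis on its own line; import sorting moves from isort to ruff's `I` lint rules (typically applied with `ruff format` and `ruff check --select I --fix`). Below is a minimal sketch of the two wrapping styles; the helper name is hypothetical and not taken from this commit.

# Illustration only: `describe_block` is a hypothetical helper, not vLLM code.

# Old style (yapf): continuation lines are aligned under the opening parenthesis.
def describe_block_yapf(num_blocks: int, block_size: int, head_dim: int) -> str:
    return "blocks={}, elems_per_block={}".format(num_blocks,
                                                  block_size * head_dim)


# New style (ruff format): a call that does not fit on one line becomes an
# indented, parenthesized block with the closing parenthesis on its own line.
def describe_block_ruff(num_blocks: int, block_size: int, head_dim: int) -> str:
    return "blocks={}, elems_per_block={}".format(
        num_blocks, block_size * head_dim
    )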


@@ -7,9 +7,14 @@ import torch
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, has_deep_gemm
-from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
-                                  fp8_paged_mqa_logits, get_num_sms,
-                                  get_paged_mqa_logits_metadata)
+from vllm.utils.deep_gemm import (
+    _ceil_to_ue8m0,
+    calc_diff,
+    fp8_mqa_logits,
+    fp8_paged_mqa_logits,
+    get_num_sms,
+    get_paged_mqa_logits_metadata,
+)
 def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
@@ -24,17 +29,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
         device=x.device,
         dtype=torch.uint8,
     )
-    x_fp8[:, :block_size * head_dim] = x_scaled.view(
-        num_blocks, block_size * head_dim).view(dtype=torch.uint8)
-    x_fp8[:,
-          block_size * head_dim:] = sf.view(num_blocks,
-                                            block_size).view(dtype=torch.uint8)
+    x_fp8[:, : block_size * head_dim] = x_scaled.view(
+        num_blocks, block_size * head_dim
+    ).view(dtype=torch.uint8)
+    x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
+        dtype=torch.uint8
+    )
     return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
 def per_custom_dims_cast_to_fp8(
-        x: torch.Tensor, dims: tuple,
-        use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
+    x: torch.Tensor, dims: tuple, use_ue8m0: bool
+) -> tuple[torch.Tensor, torch.Tensor]:
     excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
     x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
     sf = x_amax / 448.0
@@ -69,10 +75,12 @@ def _ref_fp8_mqa_logits(
     q = q.float()
     k = k.float()
-    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
-               >= cu_seqlen_ks[:, None])
-    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
-               < cu_seqlen_ke[:, None])
+    mask_lo = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+    )
+    mask_hi = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+    )
     mask = mask_lo & mask_hi
     score = torch.einsum("mhd,and->hmn", q, k)
@@ -84,14 +92,15 @@ def _ref_fp8_mqa_logits(
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
-                    reason="SM90 and SM100 only")
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+)
 def test_deepgemm_fp8_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
     num_heads, head_dim = 32, 128
-    for seq_len in (512, ):
-        for seq_len_kv in (1024, ):
+    for seq_len in (512,):
+        for seq_len_kv in (1024,):
             for disable_cp in (False, True):
                 q = torch.randn(
                     seq_len,
@@ -100,24 +109,23 @@ def test_deepgemm_fp8_mqa_logits():
                     device="cuda",
                     dtype=torch.bfloat16,
                 )
-                kv = torch.randn(seq_len_kv,
-                                 head_dim,
-                                 device="cuda",
-                                 dtype=torch.bfloat16)
-                weights = torch.randn(seq_len,
-                                      num_heads,
-                                      device="cuda",
-                                      dtype=torch.float32)
+                kv = torch.randn(
+                    seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
+                )
+                weights = torch.randn(
+                    seq_len, num_heads, device="cuda", dtype=torch.float32
+                )
                 if disable_cp:
                     ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
-                    ke = torch.arange(seq_len, dtype=torch.int,
-                                      device="cuda") + (seq_len_kv - seq_len)
+                    ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + (
+                        seq_len_kv - seq_len
+                    )
                 else:
                     ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
                 q_fp8 = q.to(torch.float8_e4m3fn)
-                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
+                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
                 logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
                 ref_logits = _ref_fp8_mqa_logits(
@@ -157,11 +165,10 @@ def _ref_fp8_paged_mqa_logits(
     context_lens_list = context_lens.tolist()
     for i in range(batch_size):
         context_len = context_lens_list[i]
-        q_offsets = torch.arange(context_len - next_n,
-                                 context_len,
-                                 device="cuda")
-        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
-            0, 1).contiguous())
+        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        weight_slice = (
+            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
+        )
         for block_rk in range(cdiv(context_len, block_size)):
             block_idx = block_tables[i][block_rk]
             qx, kx = q[i], kv_cache[block_idx]
@@ -170,28 +177,30 @@ def _ref_fp8_paged_mqa_logits(
                 (block_rk + 1) * block_size,
                 device="cuda",
             )
-            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
-                                                         <= q_offsets[:, None])
+            mask = (k_offsets[None, :] < context_len) & (
+                k_offsets[None, :] <= q_offsets[:, None]
+            )
             s = torch.where(
                 mask[None, :, :],
                 (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
-                    logits.dtype),
+                    logits.dtype
+                ),
                 float("-inf"),
             )
             s = torch.relu(s) * weight_slice[..., None]
             s = s.sum(dim=0)
             logits[
-                i * next_n:(i + 1) * next_n,
-                block_rk * block_size:(block_rk + 1) * block_size,
-            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
-                            float("-inf"))
+                i * next_n : (i + 1) * next_n,
+                block_rk * block_size : (block_rk + 1) * block_size,
+            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
     return logits
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
-                    reason="SM90 and SM100 only")
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+)
 def test_deepgemm_fp8_paged_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
@@ -199,7 +208,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
     max_model_len = 4096
     for batch_size, next_n in [(4, 1), (2, 2)]:
         for heads, index_dim in [(32, 128)]:
-            for avg_kv in (2048, ):
+            for avg_kv in (2048,):
                 num_blocks, blocksize = max_model_len * 2, 64
                 q = torch.randn(
@@ -218,12 +227,14 @@ def test_deepgemm_fp8_paged_mqa_logits():
                     dtype=torch.float32,
                 )
-                context_lens = (torch.randint(int(0.8 * avg_kv),
-                                              int(1.2 * avg_kv),
-                                              (batch_size, )).cuda().to(
-                                                  torch.int32))
-                max_block_len = ((context_lens.max().item() + blocksize - 1) //
-                                 blocksize * blocksize)
+                context_lens = (
+                    torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
+                    .cuda()
+                    .to(torch.int32)
+                )
+                max_block_len = (
+                    (context_lens.max().item() + blocksize - 1) // blocksize * blocksize
+                )
                 block_tables = torch.zeros(
                     (batch_size, max_block_len),
                     device="cuda",
@@ -243,7 +254,8 @@ def test_deepgemm_fp8_paged_mqa_logits():
                 kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
                 schedule_metadata = get_paged_mqa_logits_metadata(
-                    context_lens, blocksize, get_num_sms())
+                    context_lens, blocksize, get_num_sms()
+                )
                 logits = fp8_paged_mqa_logits(
                     q_fp8,
                     kv_cache_fp8,
@@ -263,15 +275,18 @@ def test_deepgemm_fp8_paged_mqa_logits():
                     max_model_len,
                 )
-                positions = (torch.arange(max_model_len,
-                                          device="cuda").unsqueeze(0).expand(
-                                              batch_size * next_n, -1))
-                row_indices = (
-                    torch.arange(batch_size * next_n, device="cuda") // next_n)
+                positions = (
+                    torch.arange(max_model_len, device="cuda")
+                    .unsqueeze(0)
+                    .expand(batch_size * next_n, -1)
+                )
+                row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
                 next_n_offset = (
-                    torch.arange(batch_size * next_n, device="cuda") % next_n)
-                mask = positions <= (context_lens[row_indices] - next_n +
-                                     next_n_offset).unsqueeze(1)
+                    torch.arange(batch_size * next_n, device="cuda") % next_n
+                )
+                mask = positions <= (
+                    context_lens[row_indices] - next_n + next_n_offset
+                ).unsqueeze(1)
                 logits = logits.masked_fill(~mask, 0)
                 ref_logits = ref_logits.masked_fill(~mask, 0)