Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -7,9 +7,14 @@ import torch
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, has_deep_gemm
-from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
-                                  fp8_paged_mqa_logits, get_num_sms,
-                                  get_paged_mqa_logits_metadata)
+from vllm.utils.deep_gemm import (
+    _ceil_to_ue8m0,
+    calc_diff,
+    fp8_mqa_logits,
+    fp8_paged_mqa_logits,
+    get_num_sms,
+    get_paged_mqa_logits_metadata,
+)


 def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
@@ -24,17 +29,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
         device=x.device,
         dtype=torch.uint8,
     )
-    x_fp8[:, :block_size * head_dim] = x_scaled.view(
-        num_blocks, block_size * head_dim).view(dtype=torch.uint8)
-    x_fp8[:,
-          block_size * head_dim:] = sf.view(num_blocks,
-                                            block_size).view(dtype=torch.uint8)
+    x_fp8[:, : block_size * head_dim] = x_scaled.view(
+        num_blocks, block_size * head_dim
+    ).view(dtype=torch.uint8)
+    x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
+        dtype=torch.uint8
+    )
     return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)


 def per_custom_dims_cast_to_fp8(
-        x: torch.Tensor, dims: tuple,
-        use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
+    x: torch.Tensor, dims: tuple, use_ue8m0: bool
+) -> tuple[torch.Tensor, torch.Tensor]:
     excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
     x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
     sf = x_amax / 448.0
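
The `head_dim + 4` in the final `view` only adds up because each token's float32 scale factor (4 bytes) is packed after its fp8 payload, and `num_heads` is 1 for the MQA cache. A minimal, CPU-only sketch of that byte accounting (the concrete sizes below are illustrative, not taken from the test):

```python
# Byte accounting for the packed fp8 KV-cache layout above (illustrative
# sizes; num_heads = 1 for MQA, which is what makes the final view line up).
num_blocks, block_size, num_heads, head_dim = 8, 64, 1, 128

fp8_payload = block_size * head_dim  # 1 byte per float8_e4m3fn value
scale_tail = block_size * 4          # 1 float32 scale (4 bytes) per token
bytes_per_block = fp8_payload + scale_tail

# Same total, reshaped the way kv_cache_cast_to_fp8 returns it:
assert bytes_per_block == block_size * num_heads * (head_dim + 4)
print(bytes_per_block, "bytes per block")
```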
@@ -69,10 +75,12 @@ def _ref_fp8_mqa_logits(
     q = q.float()
     k = k.float()

-    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
-               >= cu_seqlen_ks[:, None])
-    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
-               < cu_seqlen_ke[:, None])
+    mask_lo = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+    )
+    mask_hi = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+    )
     mask = mask_lo & mask_hi

     score = torch.einsum("mhd,and->hmn", q, k)
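
The two broadcast comparisons reformatted above build a per-query visibility window [ks, ke) over KV positions. A small CPU sketch of the same construction (toy values, with `device="cuda"` dropped):

```python
import torch

# Each query row i may attend to KV positions j with ks[i] <= j < ke[i].
seq_len_kv = 6
cu_seqlen_ks = torch.tensor([0, 2])  # toy window starts
cu_seqlen_ke = torch.tensor([4, 6])  # toy window ends (exclusive)

pos = torch.arange(0, seq_len_kv)[None, :]  # shape (1, seq_len_kv)
mask_lo = pos >= cu_seqlen_ks[:, None]      # broadcasts to (2, seq_len_kv)
mask_hi = pos < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
print(mask.int())  # row 0 sees positions 0..3, row 1 sees positions 2..5
```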
@@ -84,14 +92,15 @@ def _ref_fp8_mqa_logits(

 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
-                    reason="SM90 and SM100 only")
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+)
 def test_deepgemm_fp8_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)

     num_heads, head_dim = 32, 128
-    for seq_len in (512, ):
-        for seq_len_kv in (1024, ):
+    for seq_len in (512,):
+        for seq_len_kv in (1024,):
             for disable_cp in (False, True):
                 q = torch.randn(
                     seq_len,
@@ -100,24 +109,23 @@ def test_deepgemm_fp8_mqa_logits():
                     device="cuda",
                     dtype=torch.bfloat16,
                 )
-                kv = torch.randn(seq_len_kv,
-                                 head_dim,
-                                 device="cuda",
-                                 dtype=torch.bfloat16)
-                weights = torch.randn(seq_len,
-                                      num_heads,
-                                      device="cuda",
-                                      dtype=torch.float32)
+                kv = torch.randn(
+                    seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
+                )
+                weights = torch.randn(
+                    seq_len, num_heads, device="cuda", dtype=torch.float32
+                )

                 if disable_cp:
                     ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
-                    ke = torch.arange(seq_len, dtype=torch.int,
-                                      device="cuda") + (seq_len_kv - seq_len)
+                    ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + (
+                        seq_len_kv - seq_len
+                    )
                 else:
                     ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)

                 q_fp8 = q.to(torch.float8_e4m3fn)
-                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
+                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
                 logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)

                 ref_logits = _ref_fp8_mqa_logits(
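
`per_custom_dims_cast_to_fp8(kv, (0,), False)` keeps statistics along dim 0; the `amax / 448.0` scale in its body maps each slice into the finite range of `float8_e4m3fn`, whose maximum representable value is 448. A CPU sketch of that scaling step, under the assumption that the part of the body not shown in this hunk divides by `sf` before casting:

```python
import torch

# Reproduce the scale computation shown in per_custom_dims_cast_to_fp8.
x = torch.randn(4, 8)
dims = (0,)  # dims whose statistics are kept
excluded_dims = tuple(i for i in range(x.dim()) if i not in set(dims))
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0  # 448 = largest finite value of float8_e4m3fn

x_scaled = x / sf    # assumed next step: values now bounded by +/-448
assert x_scaled.abs().max().item() <= 448.0 + 1e-3
x_fp8 = x_scaled.to(torch.float8_e4m3fn)  # CPU cast works on recent torch
```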
@@ -157,11 +165,10 @@ def _ref_fp8_paged_mqa_logits(
     context_lens_list = context_lens.tolist()
     for i in range(batch_size):
         context_len = context_lens_list[i]
-        q_offsets = torch.arange(context_len - next_n,
-                                 context_len,
-                                 device="cuda")
-        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
-            0, 1).contiguous())
+        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        weight_slice = (
+            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
+        )
         for block_rk in range(cdiv(context_len, block_size)):
             block_idx = block_tables[i][block_rk]
             qx, kx = q[i], kv_cache[block_idx]
@@ -170,28 +177,30 @@ def _ref_fp8_paged_mqa_logits(
                 (block_rk + 1) * block_size,
                 device="cuda",
             )
-            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
-                                                         <= q_offsets[:, None])
+            mask = (k_offsets[None, :] < context_len) & (
+                k_offsets[None, :] <= q_offsets[:, None]
+            )
             s = torch.where(
                 mask[None, :, :],
                 (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
-                    logits.dtype),
+                    logits.dtype
+                ),
                 float("-inf"),
             )
             s = torch.relu(s) * weight_slice[..., None]
             s = s.sum(dim=0)
             logits[
-                i * next_n:(i + 1) * next_n,
-                block_rk * block_size:(block_rk + 1) * block_size,
-            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
-                            float("-inf"))
+                i * next_n : (i + 1) * next_n,
+                block_rk * block_size : (block_rk + 1) * block_size,
+            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
     return logits


 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
-                    reason="SM90 and SM100 only")
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+)
 def test_deepgemm_fp8_paged_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
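
The reference kernel's mask combines a length bound with causality: a KV offset is scored only if it lies inside the context and at or before the query's own position. A CPU sketch of that predicate for a single block (toy shapes):

```python
import torch

# One block of the reference mask: valid iff k < context_len and k <= q.
context_len, next_n, block_size = 10, 2, 8
q_offsets = torch.arange(context_len - next_n, context_len)  # queries 8, 9
k_offsets = torch.arange(0, block_size)                      # first block

mask = (k_offsets[None, :] < context_len) & (
    k_offsets[None, :] <= q_offsets[:, None]
)
print(mask.int())  # one row per query, one column per KV slot in the block
```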
@@ -199,7 +208,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
     max_model_len = 4096
     for batch_size, next_n in [(4, 1), (2, 2)]:
         for heads, index_dim in [(32, 128)]:
-            for avg_kv in (2048, ):
+            for avg_kv in (2048,):
                 num_blocks, blocksize = max_model_len * 2, 64

                 q = torch.randn(
@@ -218,12 +227,14 @@ def test_deepgemm_fp8_paged_mqa_logits():
                     dtype=torch.float32,
                 )

-                context_lens = (torch.randint(int(0.8 * avg_kv),
-                                              int(1.2 * avg_kv),
-                                              (batch_size, )).cuda().to(
-                                                  torch.int32))
-                max_block_len = ((context_lens.max().item() + blocksize - 1) //
-                                 blocksize * blocksize)
+                context_lens = (
+                    torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
+                    .cuda()
+                    .to(torch.int32)
+                )
+                max_block_len = (
+                    (context_lens.max().item() + blocksize - 1) // blocksize * blocksize
+                )
                 block_tables = torch.zeros(
                     (batch_size, max_block_len),
                     device="cuda",
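
The `max_block_len` expression is the standard integer round-up to a multiple of the block size, `(n + b - 1) // b * b`, which sizes the block tables below. A quick check with toy values:

```python
# Round context lengths up to a whole number of blocks (toy values).
blocksize = 64
for n in (1, 64, 65, 130):
    max_block_len = (n + blocksize - 1) // blocksize * blocksize
    print(n, "->", max_block_len)  # 64, 64, 128, 192
```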
@@ -243,7 +254,8 @@ def test_deepgemm_fp8_paged_mqa_logits():
                 kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

                 schedule_metadata = get_paged_mqa_logits_metadata(
-                    context_lens, blocksize, get_num_sms())
+                    context_lens, blocksize, get_num_sms()
+                )
                 logits = fp8_paged_mqa_logits(
                     q_fp8,
                     kv_cache_fp8,
@@ -263,15 +275,18 @@ def test_deepgemm_fp8_paged_mqa_logits():
                     max_model_len,
                 )

-                positions = (torch.arange(max_model_len,
-                                          device="cuda").unsqueeze(0).expand(
-                                              batch_size * next_n, -1))
-                row_indices = (
-                    torch.arange(batch_size * next_n, device="cuda") // next_n)
+                positions = (
+                    torch.arange(max_model_len, device="cuda")
+                    .unsqueeze(0)
+                    .expand(batch_size * next_n, -1)
+                )
+                row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
                 next_n_offset = (
-                    torch.arange(batch_size * next_n, device="cuda") % next_n)
-                mask = positions <= (context_lens[row_indices] - next_n +
-                                     next_n_offset).unsqueeze(1)
+                    torch.arange(batch_size * next_n, device="cuda") % next_n
+                )
+                mask = positions <= (
+                    context_lens[row_indices] - next_n + next_n_offset
+                ).unsqueeze(1)

                 logits = logits.masked_fill(~mask, 0)
                 ref_logits = ref_logits.masked_fill(~mask, 0)
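
The final comparison masks each of the `batch_size * next_n` logit rows to its own position: row r belongs to request `r // next_n` and step `r % next_n`, and may see positions up to `context_len - next_n + step`. A CPU sketch of the same construction with toy sizes:

```python
import torch

# Rebuild the validity mask from the hunk above on CPU with toy sizes.
batch_size, next_n, max_model_len = 2, 2, 8
context_lens = torch.tensor([5, 7])

positions = (
    torch.arange(max_model_len)
    .unsqueeze(0)
    .expand(batch_size * next_n, -1)
)
row_indices = torch.arange(batch_size * next_n) // next_n   # request id per row
next_n_offset = torch.arange(batch_size * next_n) % next_n  # step id per row
mask = positions <= (
    context_lens[row_indices] - next_n + next_n_offset
).unsqueeze(1)
print(mask.int())
```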