Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,30 +11,24 @@ from vllm.utils import cdiv
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
@@ -63,18 +57,17 @@ def test_mla_decode_cpu(
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = d**(-0.5)
|
||||
scale = d ** (-0.5)
|
||||
if varlen:
|
||||
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
|
||||
seq_lens = seq_lens.clip(2).to(torch.int32)
|
||||
else:
|
||||
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
|
||||
seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
|
||||
max_seq_len = seq_lens.max().item()
|
||||
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
|
||||
|
||||
q = torch.randn(bs, h_q, d)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size,
|
||||
dtype=torch.int32)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32)
|
||||
block_table = block_table.view(bs, seqlen_pad // block_size)
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
@@ -82,8 +75,7 @@ def test_mla_decode_cpu(
|
||||
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
|
||||
|
||||
out_mla = q.new_zeros(bs, h_q, dv)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
|
||||
seq_lens)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
Reference in New Issue
Block a user