Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
num_kv_splits = 8
|
||||
|
||||
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
|
||||
req_to_page = torch.randint(0,
|
||||
CACHE_SIZE // PAGE_SIZE,
|
||||
(B, num_pages_per_batch, 1),
|
||||
device="cuda")
|
||||
req_to_page = torch.randint(
|
||||
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
|
||||
)
|
||||
req_to_token = req_to_page * PAGE_SIZE
|
||||
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
|
||||
1, 1, -1)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
|
||||
req_to_token = req_to_token.view(B, -1)
|
||||
req_to_token = req_to_token[:, :seq_len].contiguous()
|
||||
|
||||
@@ -48,7 +46,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
|
||||
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len = torch.full((B, ), seq_len, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
|
||||
Reference in New Issue
Block a user