Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -6,9 +6,11 @@ from typing import Optional
import torch
from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -42,10 +44,11 @@ def forward_attention(
num_kv_heads = k.shape[-2]
# Initialize the query and KV sequence lengths.
query_start_loc = q_len * torch.arange(
batch_size + 1, device=q.device, dtype=torch.int32)
batch_size + 1, device=q.device, dtype=torch.int32
)
query_lens = torch.diff(query_start_loc)
seq_lens = torch.full(
(batch_size, ),
(batch_size,),
seqlen_k,
device=q.device,
dtype=torch.int32,
@@ -55,14 +58,13 @@ def forward_attention(
max_query_len = q_len
num_actual_tokens = query_start_loc[-1]
softmax_scale = q.shape[-1]**(-0.5)
softmax_scale = q.shape[-1] ** (-0.5)
layer = MockAttentionLayer()
# Build common metadata.
model_name = "meta-llama/Meta-Llama-3-8B"
builder_cls, impl_cls = get_attention_backend(backend)
vllm_config = create_vllm_config(model_name=model_name,
max_model_len=max(seq_lens))
vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
if spec_token_tree is not None:
# Create speculative config if token tree is specified.
vllm_config.speculative_config = SpeculativeConfig(
@@ -71,7 +73,8 @@ def forward_attention(
model=model_name,
method="eagle",
num_speculative_tokens=num_spec_tokens,
speculative_token_tree=spec_token_tree)
speculative_token_tree=spec_token_tree,
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
common_attn_metadata = CommonAttentionMetadata(
@@ -128,8 +131,7 @@ def test_tree_attn_correctness() -> None:
device = "cuda"
tree_attn_masks = {
# Chain.
"[(0,), (0, 0), (0, 0, 0)]":
torch.tensor(
"[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
@@ -140,8 +142,7 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int32,
),
# Tree.
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
torch.tensor(
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
@@ -202,8 +203,7 @@ def test_tree_attn_correctness() -> None:
device=q.device,
dtype=torch.bfloat16,
)
num_alloc_blocks_per_batch = math.ceil(seqlen_k /
block_size)
num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
block_table = torch.zeros(
(batch_size, max_blocks_per_batch),
device=q.device,
@@ -217,11 +217,10 @@ def test_tree_attn_correctness() -> None:
)
if randomize_blocks:
# Randomize the block ids.
block_ids = block_ids[torch.randperm(
block_ids.numel())]
block_table[:, :
num_alloc_blocks_per_batch] = block_ids.view(
-1, num_alloc_blocks_per_batch)
block_ids = block_ids[torch.randperm(block_ids.numel())]
block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
-1, num_alloc_blocks_per_batch
)
# Set up the slot mapping for the input KVs.
tree_positions = sequence_position + torch.arange(
@@ -231,7 +230,8 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int64,
).repeat(batch_size, 1)
tree_slot_mapping = _gen_slot_mapping(
tree_positions, block_table, block_size)
tree_positions, block_table, block_size
)
# Compute attention for the tree.
tree_attn_output = forward_attention(
@@ -253,8 +253,7 @@ def test_tree_attn_correctness() -> None:
for q_index in range(tree_size_q):
# Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :]
branch_indices = torch.nonzero(branch_mask,
as_tuple=True)[0]
branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
q_len = branch_indices.shape[0]
q_branch = q[:, branch_indices]
k_branch = k[:, branch_indices]
@@ -268,7 +267,8 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int64,
).repeat(batch_size, 1)
branch_slot_mapping = _gen_slot_mapping(
branch_positions, block_table, block_size)
branch_positions, block_table, block_size
)
# Compute flash attention for the branch.
flash_attn_output = forward_attention(
@@ -287,16 +287,19 @@ def test_tree_attn_correctness() -> None:
tree_attn_output[:, branch_indices],
flash_attn_output,
atol=7.81e-3,
), (f"outputs are not close for "
), (
f"outputs are not close for "
f"batch_size: {batch_size}, "
f"num_heads: {num_heads}, "
f"sequence_position: {sequence_position}, "
f"tree_attn_mask: {tree_attn_mask}, "
f"q_index: {q_index}.")
f"q_index: {q_index}."
)
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
block_size: int):
def _gen_slot_mapping(
positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
block_indices = positions // block_size
blocks = block_table.gather(dim=1, index=block_indices)
return (blocks * block_size + positions % block_size).view(-1)