Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
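For context on how a conversion like this can be applied locally: a minimal sketch, assuming ruff is already installed and configured in the repository's pyproject.toml. `ruff check --select I --fix` covers import sorting (the isort replacement) and `ruff format` covers code formatting (the yapf replacement); the helper name and target path below are illustrative, not taken from this commit.

# Illustrative sketch only: run the ruff equivalents of isort + yapf over a path.
# Assumes ruff is installed; the function name and default path are placeholders.
import subprocess


def apply_ruff(path: str = ".") -> None:
    # Sort imports using ruff's isort-compatible "I" rules (replaces isort).
    subprocess.run(["ruff", "check", "--select", "I", "--fix", path], check=True)
    # Reformat the code with ruff's formatter (replaces yapf).
    subprocess.run(["ruff", "format", path], check=True)


if __name__ == "__main__":
    apply_ruff()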
@@ -9,11 +9,13 @@ import pytest
 import torch
 
 from packaging import version
 
-from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
-                                      create_standard_kv_cache_spec,
-                                      create_vllm_config)
-from vllm.v1.attention.backends.flex_attention import (
-    FlexAttentionMetadataBuilder)
+from tests.v1.attention.utils import (
+    BatchSpec,
+    create_common_attn_metadata,
+    create_standard_kv_cache_spec,
+    create_vllm_config,
+)
+from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
 
 from ..models.utils import check_embeddings_close, check_logprobs_close
@@ -57,26 +59,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
 
         set_seed(seed)
-        with vllm_runner(model_name,
-                         runner="generate",
-                         tensor_parallel_size=1,
-                         num_gpu_blocks_override=128,
-                         enforce_eager=True) as llm_flex:
+        with vllm_runner(
+            model_name,
+            runner="generate",
+            tensor_parallel_size=1,
+            num_gpu_blocks_override=128,
+            enforce_eager=True,
+        ) as llm_flex:
             output_flex = llm_flex.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs)
+                prompts, max_tokens, num_logprobs
+            )
 
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         set_seed(seed)
-        with vllm_runner(model_name,
-                         runner="generate",
-                         tensor_parallel_size=1,
-                         num_gpu_blocks_override=128,
-                         enforce_eager=True,
-                         gpu_memory_utilization=0.85) as llm_default:
+        with vllm_runner(
+            model_name,
+            runner="generate",
+            tensor_parallel_size=1,
+            num_gpu_blocks_override=128,
+            enforce_eager=True,
+            gpu_memory_utilization=0.85,
+        ) as llm_default:
             output_default = llm_default.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs)
+                prompts, max_tokens, num_logprobs
+            )
 
     check_logprobs_close(
         outputs_0_lst=output_flex,
@@ -107,23 +115,27 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        with vllm_runner(model_name,
-                         runner="pooling",
-                         dtype=torch.bfloat16,
-                         tensor_parallel_size=1,
-                         max_model_len=100,
-                         enforce_eager=True) as llm_flex:
+        with vllm_runner(
+            model_name,
+            runner="pooling",
+            dtype=torch.bfloat16,
+            tensor_parallel_size=1,
+            max_model_len=100,
+            enforce_eager=True,
+        ) as llm_flex:
             flex_outputs = llm_flex.embed(prompts)
 
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        with vllm_runner(model_name,
-                         runner="pooling",
-                         dtype=torch.bfloat16,
-                         tensor_parallel_size=1,
-                         max_model_len=100,
-                         enforce_eager=True) as llm_default:
+        with vllm_runner(
+            model_name,
+            runner="pooling",
+            dtype=torch.bfloat16,
+            tensor_parallel_size=1,
+            max_model_len=100,
+            enforce_eager=True,
+        ) as llm_default:
             default_outputs = llm_default.embed(prompts)
 
     check_embeddings_close(
@@ -147,27 +159,29 @@ def test_block_mask_direct_vs_slow_path():
     """
     device = torch.device("cuda")
 
-    vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
-                                     block_size=16,
-                                     max_model_len=1024)
+    vllm_config = create_vllm_config(
+        model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024
+    )
     kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
 
     # Use a mixed batch that will create groups spanning multiple sequences
-    batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
-                           query_lens=[33, 5, 32, 64],
-                           name="test_mixed_batch")
+    batch_spec = BatchSpec(
+        seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch"
+    )
 
     common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device)
+        batch_spec, vllm_config.cache_config.block_size, device
+    )
 
-    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
-                                           device)
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)
 
-    metadata_direct = builder.build(common_prefix_len=0,
-                                    common_attn_metadata=common_attn_metadata)
+    metadata_direct = builder.build(
+        common_prefix_len=0, common_attn_metadata=common_attn_metadata
+    )
     builder.direct_build = False
-    metadata_slow = builder.build(common_prefix_len=0,
-                                  common_attn_metadata=common_attn_metadata)
+    metadata_slow = builder.build(
+        common_prefix_len=0, common_attn_metadata=common_attn_metadata
+    )
 
     assert metadata_direct.block_mask is not None
     assert metadata_slow.block_mask is not None
@@ -184,20 +198,20 @@ def test_block_mask_direct_vs_slow_path():
     missing_details = []
 
     for group_idx in range(num_groups):
-        direct_blocks = set(
-            direct_indices[group_idx, :direct_num[group_idx]].tolist())
-        slow_blocks = set(
-            slow_indices[group_idx, :slow_num[group_idx]].tolist())
+        direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist())
+        slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist())
 
         missing_blocks = slow_blocks - direct_blocks
         if missing_blocks:
             all_contained = False
             missing_details.append(
-                f"Group {group_idx}: missing {sorted(missing_blocks)}")
+                f"Group {group_idx}: missing {sorted(missing_blocks)}"
+            )
 
     assert all_contained, (
-        "Direct path is missing blocks required by slow path:\n" +
-        "\n".join(missing_details))
+        "Direct path is missing blocks required by slow path:\n"
+        + "\n".join(missing_details)
+    )
 
 
 if __name__ == "__main__":
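A note on why most hunks above switch from yapf's paren-aligned wrapping to one argument per line: like Black, ruff's formatter honors the "magic trailing comma", so a call that ends in a trailing comma stays fully expanded, while a call without one is packed onto as few lines as fit the line-length limit. The snippet below is a minimal, self-contained sketch of the two outcomes; the function and arguments are invented for illustration, not code from this commit.

# Minimal sketch (not from this commit) of ruff's magic trailing comma behavior.
def configure(model: str, runner: str, tp_size: int, enforce_eager: bool) -> dict:
    return {"model": model, "runner": runner, "tp": tp_size, "eager": enforce_eager}


# No trailing comma and the call fits the line limit: ruff collapses it to one line.
compact = configure("meta-llama/Meta-Llama-3-8B", "generate", 1, enforce_eager=True)

# Trailing comma after the last argument: ruff keeps one argument per line.
expanded = configure(
    "meta-llama/Meta-Llama-3-8B",
    "generate",
    1,
    enforce_eager=True,
)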