Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -14,25 +14,25 @@ from vllm.platforms import current_platform
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size: int, max_position_embeddings: int,
dtype: torch.dtype, device: torch.device):
def generate_test_data(
num_tokens: int,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
max_position_embeddings: int,
dtype: torch.dtype,
device: torch.device,
):
"""Generate test data for given configuration."""
current_platform.seed_everything(42)
# Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(0,
max_position_embeddings // 4, (3, num_tokens),
device=device)
positions = torch.randint(
0, max_position_embeddings // 4, (3, num_tokens), device=device
)
# Create query and key tensors
query = torch.randn(num_tokens,
num_q_heads * head_size,
dtype=dtype,
device=device)
key = torch.randn(num_tokens,
num_kv_heads * head_size,
dtype=dtype,
device=device)
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
return positions, query, key
@@ -59,7 +59,8 @@ MODELS_TO_TEST = [
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
],
),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
@@ -67,24 +68,33 @@ MODELS_TO_TEST = [
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
],
),
]
num_tokens_list = [11, 8192]
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize(
"model_info, model_name",
[
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
def test_mrope(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
):
atol = model_info.atol
rtol = model_info.rtol
@@ -96,8 +106,11 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
head_dim = (
config.head_dim
if hasattr(config, "head_dim")
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
@@ -117,9 +130,9 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
# create q k v input tensors
# create rotary pos emb input tensors
positions, query, key = generate_test_data(num_tokens, num_heads,
num_kv_heads, head_dim,
max_position, dtype, device)
positions, query, key = generate_test_data(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)
query_native, key_native = mrope_helper_class.forward_native(
positions,
@@ -137,19 +150,26 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize(
"model_info, model_name",
[
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing(model_name: str,
model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
def test_mrope_torch_compile_tracing(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
):
atol = model_info.atol
rtol = model_info.rtol
@@ -161,8 +181,11 @@ def test_mrope_torch_compile_tracing(model_name: str,
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
head_dim = (
config.head_dim
if hasattr(config, "head_dim")
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
@@ -180,16 +203,16 @@ def test_mrope_torch_compile_tracing(model_name: str,
).to(device=device)
# Generate test data
positions, query, key = generate_test_data(num_tokens, num_heads,
num_kv_heads, head_dim,
max_position, dtype, device)
positions, query, key = generate_test_data(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)
# Create a wrapper that makes the in-place function appear functional
def functional_forward_cuda(pos, q, k):
"""Wrapper that converts in-place operation to functional style
CUDA Graph does not support in-place operations.
This wrapper creates working copies of the
This wrapper creates working copies of the
input tensors and modifies them.
"""
q_work = q.clone() # Create working copies
@@ -206,11 +229,13 @@ def test_mrope_torch_compile_tracing(model_name: str,
)
try:
compiled_forward_cuda = torch.compile(functional_forward_cuda,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_forward_cuda = torch.compile(
functional_forward_cuda,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
# Run compiled version
query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
@@ -225,25 +250,16 @@ def test_mrope_torch_compile_tracing(model_name: str,
mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
# Verify results
torch.testing.assert_close(query_compiled_cuda,
query_cuda,
atol=atol,
rtol=rtol)
torch.testing.assert_close(key_compiled_cuda,
key_cuda,
atol=atol,
rtol=rtol)
torch.testing.assert_close(query_compiled_cuda,
query_native,
atol=atol,
rtol=rtol)
torch.testing.assert_close(key_compiled_cuda,
key_native,
atol=atol,
rtol=rtol)
torch.testing.assert_close(
query_compiled_cuda, query_cuda, atol=atol, rtol=rtol
)
torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol)
torch.testing.assert_close(
query_compiled_cuda, query_native, atol=atol, rtol=rtol
)
torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol)
print("✓ forward_cuda successfully traced with torch.compile inductor")
except Exception as e:
pytest.fail(
f"forward_cuda failed to trace with torch.compile inductor: {e}")
pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}")