Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
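
Only the Python-level effect of the switch is visible in the excerpt below. The tooling change itself lives in files such as pyproject.toml and .pre-commit-config.yaml, which are not part of this excerpt. As a minimal sketch of the kind of configuration such a migration typically lands on (the values and rule selection here are illustrative assumptions, not the exact settings from this commit):

    # pyproject.toml -- illustrative sketch, not the exact config from this commit
    [tool.ruff]
    line-length = 88

    [tool.ruff.lint]
    # "I" enables ruff's isort-compatible import-sorting rules,
    # replacing the standalone isort hook; "E"/"F" are pycodestyle
    # and pyflakes rules.
    select = ["E", "F", "I"]

Formatting is then applied with "ruff format ." and import sorting with "ruff check --select I --fix .", replacing the previous yapf and isort invocations.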


@@ -6,13 +6,22 @@ from unittest import mock
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
-                                      create_standard_kv_cache_spec,
-                                      get_attention_backend)
+from tests.v1.attention.utils import (
+    BatchSpec,
+    create_common_attn_metadata,
+    create_standard_kv_cache_spec,
+    get_attention_backend,
+)
 from vllm.attention.backends.registry import _Backend
-from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VllmConfig)
+from vllm.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    SpeculativeConfig,
+    VllmConfig,
+)
 from vllm.config.load import LoadConfig
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
@@ -23,10 +32,9 @@ mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
 
 
 def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
     """Create an MTP proposer with unified model configuration."""
-    model_config = ModelConfig(model=mimo_7b_dir,
-                               runner="generate",
-                               max_model_len=100,
-                               trust_remote_code=True)
+    model_config = ModelConfig(
+        model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True
+    )
     speculative_config = SpeculativeConfig(
         target_model_config=model_config,
@@ -43,17 +51,16 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig())
+        scheduler_config=SchedulerConfig(),
+    )
 
-    return EagleProposer(vllm_config=vllm_config,
-                         device=current_platform.device_type)
+    return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
 
 
-@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
-@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model')
-def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
-                                mock_get_pp_group):
+@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
+@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
+@mock.patch("vllm.v1.spec_decode.eagle.get_model")
+def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
     """Test MTP-specific model loading with unified model approach."""
 
     # Setup mocks
@@ -67,8 +74,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
     all_indexer_layers: dict = {}
 
     mock_get_layers.side_effect = [
-        target_attn_layers, target_indexer_layers, all_attn_layers,
-        all_indexer_layers
+        target_attn_layers,
+        target_indexer_layers,
+        all_attn_layers,
+        all_indexer_layers,
     ]
 
     mock_pp_group = mock.MagicMock()
@@ -116,17 +125,13 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
 
     # MTP returns hidden states directly
     if num_speculative_tokens == 1:
-        model_mock.return_value = torch.zeros(total_tokens,
-                                              hidden_size,
-                                              device=device)
+        model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device)
     else:
         # Multiple forward passes for multi-token speculation
         forward_returns = []
         for i in range(num_speculative_tokens):
             if i == 0:
-                h_states = torch.zeros(total_tokens,
-                                       hidden_size,
-                                       device=device)
+                h_states = torch.zeros(total_tokens, hidden_size, device=device)
             else:
                 h_states = torch.zeros(batch_size, hidden_size, device=device)
             forward_returns.append(h_states)
@@ -140,7 +145,8 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
 
     if num_speculative_tokens == 1:
         model_mock.compute_logits.return_value = create_deterministic_logits(
-            batch_size, vocab_size, 42)
+            batch_size, vocab_size, 42
+        )
     else:
         logits_returns = [
             create_deterministic_logits(batch_size, vocab_size, 42 + i)
@@ -153,24 +159,21 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
 
     # Prepare inputs
     batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
-    common_attn_metadata = create_common_attn_metadata(batch_spec,
-                                                       block_size=16,
-                                                       device=device)
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, block_size=16, device=device
+    )
 
-    target_token_ids = torch.randint(0,
-                                     vocab_size, (total_tokens, ),
-                                     device=device)
-    target_positions = torch.cat([
-        torch.arange(seq_lens[0], device=device),
-        torch.arange(seq_lens[1], device=device)
-    ])
-    target_hidden_states = torch.randn(total_tokens,
-                                       hidden_size,
-                                       device=device)
-    next_token_ids = torch.randint(0,
-                                   vocab_size, (batch_size, ),
-                                   dtype=torch.int32,
-                                   device=device)
+    target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
+    target_positions = torch.cat(
+        [
+            torch.arange(seq_lens[0], device=device),
+            torch.arange(seq_lens[1], device=device),
+        ]
+    )
+    target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
+    next_token_ids = torch.randint(
+        0, vocab_size, (batch_size,), dtype=torch.int32, device=device
+    )
     sampling_metadata = mock.MagicMock()
 
     # Setup attention metadata
@@ -187,13 +190,15 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
     proposer.attn_metadata_builder = attn_metadata_builder
 
     # Run propose
-    result = proposer.propose(target_token_ids=target_token_ids,
-                              target_positions=target_positions,
-                              target_hidden_states=target_hidden_states,
-                              next_token_ids=next_token_ids,
-                              last_token_indices=None,
-                              common_attn_metadata=common_attn_metadata,
-                              sampling_metadata=sampling_metadata)
+    result = proposer.propose(
+        target_token_ids=target_token_ids,
+        target_positions=target_positions,
+        target_hidden_states=target_hidden_states,
+        next_token_ids=next_token_ids,
+        last_token_indices=None,
+        common_attn_metadata=common_attn_metadata,
+        sampling_metadata=sampling_metadata,
+    )
 
     # Verify the model was called correctly
     assert model_mock.called
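
A note on the pattern visible throughout the hunks above: ruff's formatter follows Black's "magic trailing comma" convention, which is why some calls collapse onto a single line while others expand to one argument per line. A small illustrative sketch (the function name f is hypothetical):

    # With a trailing comma after the last argument, ruff keeps the
    # call expanded, one argument per line:
    f(
        a,
        b,
    )
    # Without a trailing comma, the arguments are collapsed onto a
    # single line whenever they fit within the line length:
    f(a, b)

This is why the multi-line forms in the new code all gained trailing commas, while short calls such as ModelConfig(...) were joined onto one line.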