Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,13 +6,22 @@ from unittest import mock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
get_attention_backend)
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VllmConfig)
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
@@ -23,10 +32,9 @@ mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
|
||||
|
||||
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
|
||||
"""Create an MTP proposer with unified model configuration."""
|
||||
model_config = ModelConfig(model=mimo_7b_dir,
|
||||
runner="generate",
|
||||
max_model_len=100,
|
||||
trust_remote_code=True)
|
||||
model_config = ModelConfig(
|
||||
model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True
|
||||
)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
@@ -43,17 +51,16 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
|
||||
device_config=DeviceConfig(device=current_platform.device_type),
|
||||
parallel_config=ParallelConfig(),
|
||||
load_config=LoadConfig(),
|
||||
scheduler_config=SchedulerConfig())
|
||||
scheduler_config=SchedulerConfig(),
|
||||
)
|
||||
|
||||
return EagleProposer(vllm_config=vllm_config,
|
||||
device=current_platform.device_type)
|
||||
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
|
||||
|
||||
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
|
||||
mock_get_pp_group):
|
||||
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
|
||||
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
|
||||
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
|
||||
def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
|
||||
"""Test MTP-specific model loading with unified model approach."""
|
||||
|
||||
# Setup mocks
|
||||
@@ -67,8 +74,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
|
||||
all_indexer_layers: dict = {}
|
||||
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indexer_layers, all_attn_layers,
|
||||
all_indexer_layers
|
||||
target_attn_layers,
|
||||
target_indexer_layers,
|
||||
all_attn_layers,
|
||||
all_indexer_layers,
|
||||
]
|
||||
|
||||
mock_pp_group = mock.MagicMock()
|
||||
@@ -116,17 +125,13 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
|
||||
# MTP returns hidden states directly
|
||||
if num_speculative_tokens == 1:
|
||||
model_mock.return_value = torch.zeros(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device)
|
||||
else:
|
||||
# Multiple forward passes for multi-token speculation
|
||||
forward_returns = []
|
||||
for i in range(num_speculative_tokens):
|
||||
if i == 0:
|
||||
h_states = torch.zeros(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
h_states = torch.zeros(total_tokens, hidden_size, device=device)
|
||||
else:
|
||||
h_states = torch.zeros(batch_size, hidden_size, device=device)
|
||||
forward_returns.append(h_states)
|
||||
@@ -140,7 +145,8 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
|
||||
if num_speculative_tokens == 1:
|
||||
model_mock.compute_logits.return_value = create_deterministic_logits(
|
||||
batch_size, vocab_size, 42)
|
||||
batch_size, vocab_size, 42
|
||||
)
|
||||
else:
|
||||
logits_returns = [
|
||||
create_deterministic_logits(batch_size, vocab_size, 42 + i)
|
||||
@@ -153,24 +159,21 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
|
||||
# Prepare inputs
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
|
||||
common_attn_metadata = create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, block_size=16, device=device
|
||||
)
|
||||
|
||||
target_token_ids = torch.randint(0,
|
||||
vocab_size, (total_tokens, ),
|
||||
device=device)
|
||||
target_positions = torch.cat([
|
||||
torch.arange(seq_lens[0], device=device),
|
||||
torch.arange(seq_lens[1], device=device)
|
||||
])
|
||||
target_hidden_states = torch.randn(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
next_token_ids = torch.randint(0,
|
||||
vocab_size, (batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
|
||||
target_positions = torch.cat(
|
||||
[
|
||||
torch.arange(seq_lens[0], device=device),
|
||||
torch.arange(seq_lens[1], device=device),
|
||||
]
|
||||
)
|
||||
target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
|
||||
next_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size,), dtype=torch.int32, device=device
|
||||
)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Setup attention metadata
|
||||
@@ -187,13 +190,15 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
proposer.attn_metadata_builder = attn_metadata_builder
|
||||
|
||||
# Run propose
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=None,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
result = proposer.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=None,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# Verify the model was called correctly
|
||||
assert model_mock.called
|
||||
|
||||
Reference in New Issue
Block a user