Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,9 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.routing_simulator import (
|
||||
DistributionBasedRouting, RoutingSimulator)
|
||||
DistributionBasedRouting,
|
||||
RoutingSimulator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -60,10 +62,10 @@ def test_basic_functionality(
|
||||
), f"Wrong ids shape for {strategy}"
|
||||
|
||||
# Check that expert IDs are valid
|
||||
assert (topk_ids.min()
|
||||
>= 0), f"Invalid expert ID (negative) for {strategy}"
|
||||
assert (topk_ids.max()
|
||||
< num_experts), f"Invalid expert ID (too large) for {strategy}"
|
||||
assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}"
|
||||
assert topk_ids.max() < num_experts, (
|
||||
f"Invalid expert ID (too large) for {strategy}"
|
||||
)
|
||||
|
||||
|
||||
def test_routing_strategy_integration(monkeypatch, device):
|
||||
@@ -102,19 +104,20 @@ def test_routing_strategy_integration(monkeypatch, device):
|
||||
top_k=top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=True,
|
||||
indices_type=torch.long)
|
||||
indices_type=torch.long,
|
||||
)
|
||||
|
||||
# Verify output shapes
|
||||
assert topk_weights.shape == (
|
||||
num_tokens, top_k), f"Wrong weights shape for {strategy}"
|
||||
assert topk_ids.shape == (num_tokens,
|
||||
top_k), f"Wrong ids shape for {strategy}"
|
||||
assert topk_weights.shape == (num_tokens, top_k), (
|
||||
f"Wrong weights shape for {strategy}"
|
||||
)
|
||||
assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}"
|
||||
|
||||
# Verify expert IDs are valid
|
||||
assert topk_ids.min(
|
||||
) >= 0, f"Invalid expert ID (negative) for {strategy}"
|
||||
assert topk_ids.max(
|
||||
) < num_experts, f"Invalid expert ID (too large) for {strategy}"
|
||||
assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}"
|
||||
assert topk_ids.max() < num_experts, (
|
||||
f"Invalid expert ID (too large) for {strategy}"
|
||||
)
|
||||
|
||||
|
||||
def test_distribution_based_routing_with_custom_strategy():
|
||||
@@ -123,9 +126,7 @@ def test_distribution_based_routing_with_custom_strategy():
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Register custom distribution-based strategy
|
||||
custom_strategy = DistributionBasedRouting(distribution="normal",
|
||||
mean=2.0,
|
||||
std=0.5)
|
||||
custom_strategy = DistributionBasedRouting(distribution="normal", mean=2.0, std=0.5)
|
||||
RoutingSimulator.register_strategy("custom_normal", custom_strategy)
|
||||
|
||||
# Test data
|
||||
@@ -142,7 +143,8 @@ def test_distribution_based_routing_with_custom_strategy():
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
strategy_name="custom_normal",
|
||||
top_k=top_k)
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Check output shapes
|
||||
assert topk_weights.shape == (num_tokens, top_k)
|
||||
@@ -165,7 +167,8 @@ def test_instance_compatibility():
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
strategy_name="uniform_random",
|
||||
top_k=2)
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
assert topk_weights.shape == (10, 2)
|
||||
assert topk_ids.shape == (10, 2)
|
||||
|
||||
Reference in New Issue
Block a user