Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Commit: d6953beb91 (parent: 17edd8a807)
1508 changed files with 115244 additions and 94146 deletions
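The change itself is mechanical: ruff takes over both jobs, with `ruff format` replacing yapf and ruff's `I` (isort) lint rules replacing isort. As a rough illustration only, a pyproject.toml for such a switch typically looks something like the sketch below (hypothetical values; the actual configuration edits live in pyproject.toml and are not part of this excerpt):

    # Sketch of a ruff setup replacing yapf + isort (not the committed config)
    [tool.ruff]
    line-length = 88             # Black-compatible width used by ruff format

    [tool.ruff.lint]
    extend-select = ["I"]        # "I" rules give isort-equivalent import sorting

    [tool.ruff.format]
    quote-style = "double"       # matches the quote-style changes visible below

With such a config, `ruff format .` rewraps calls using the magic trailing comma (one argument per line) and `ruff check --select I --fix .` re-sorts imports, which is exactly the kind of churn the diff below shows.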


@@ -6,20 +6,30 @@ import pytest
 import torch

 from vllm.attention import Attention
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, VllmConfig, set_current_vllm_config)
-from vllm.distributed.parallel_state import (init_distributed_environment,
-                                             initialize_model_parallel)
+from vllm.config import (
+    CacheConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
+from vllm.distributed.parallel_state import (
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, update_environment_variables
-from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
-                                         get_kv_cache_configs)
-from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
-                                       SchedulerOutput)
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
+from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
+from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+    KVCacheTensor,
+)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -35,8 +45,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
     """
     attn_spec = FullAttentionSpec(
         block_size=BLOCK_SIZE,
-        num_kv_heads=runner.model_config.get_num_kv_heads(
-            runner.parallel_config),
+        num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
         head_size=runner.model_config.get_head_size(),
         dtype=runner.kv_cache_dtype,
     )
@@ -58,9 +67,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
         device=runner.device,
         pin_memory=runner.pin_memory,
         vocab_size=runner.model_config.get_vocab_size(),
-        block_sizes=[
-            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
-        ],
+        block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
     )

     runner.initialize_attn_backend(kv_cache_config)
@@ -98,8 +105,9 @@ def model_runner():
     model_config = vllm_config.model_config
     num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
     head_size = model_config.get_head_size()
-    vllm_config.compilation_config.static_forward_context[
-        "layer.0"] = Attention(num_heads, head_size, 0.1)
+    vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
+        num_heads, head_size, 0.1
+    )
     runner = GPUModelRunner(vllm_config, DEVICE)
     initialize_kv_cache(runner)
     return runner
@@ -120,10 +128,11 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_features=[],
                 sampling_params=SamplingParams(),
                 pooling_params=None,
-                block_ids=([0], ),
+                block_ids=([0],),
                 num_computed_tokens=0,
                 lora_request=None,
-            ))
+            )
+        )

         num_scheduled_tokens[req_id] = 3
         total_num_scheduled_tokens += num_scheduled_tokens[req_id]
@@ -150,22 +159,22 @@ def _is_req_added(model_runner, req_id: str) -> bool:
     return req_id in model_runner.requests


-def _is_sampling_metadata_changed(model_runner,
-                                  sampling_metadata_before: SamplingMetadata):
-    return model_runner.input_batch.sampling_metadata is not (
-        sampling_metadata_before)
+def _is_sampling_metadata_changed(
+    model_runner, sampling_metadata_before: SamplingMetadata
+):
+    return model_runner.input_batch.sampling_metadata is not (sampling_metadata_before)


 def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     req_index = model_runner.input_batch.req_id_to_index[req_id]
     block_table = model_runner.input_batch.block_table[0]
     req_state = model_runner.requests[req_id]
-    if block_table.num_blocks_per_row[req_index] != len(
-            req_state.block_ids[0]):
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids[0]):
         return False
     num_blocks = block_table.num_blocks_per_row[req_index]
-    return (block_table.block_table.np[req_index, :num_blocks] ==
-            req_state.block_ids[0]).all()
+    return (
+        block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0]
+    ).all()


 def test_update_states_new_request(model_runner, dist_init):
@@ -248,7 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         req_ids=[req_id],
         resumed_from_preemption=[False],
         new_token_ids=[[]],
-        new_block_ids=([[0]], ),
+        new_block_ids=([[0]],),
         num_computed_tokens=[0],
         num_output_tokens=[0],
     )
@@ -281,46 +290,58 @@ def test_get_nans_in_logits(model_runner, dist_init):
     scheduler_output = _schedule_new_request(*req_ids)
     model_runner._update_states(scheduler_output)

-    logits = torch.tensor([
-        [1.0, 2.0, 3.0],
-        [3.0, 2.0, 1.0],
-    ], device=DEVICE)
+    logits = torch.tensor(
+        [
+            [1.0, 2.0, 3.0],
+            [3.0, 2.0, 1.0],
+        ],
+        device=DEVICE,
+    )
     result = model_runner._get_nans_in_logits(logits)
     assert result == {"req_0": 0, "req_1": 0}

-    logits = torch.tensor([
-        [1.0, float('nan'), 3.0],
-        [4.0, float('nan'), float('nan')],
-    ],
-                          device=DEVICE)
+    logits = torch.tensor(
+        [
+            [1.0, float("nan"), 3.0],
+            [4.0, float("nan"), float("nan")],
+        ],
+        device=DEVICE,
+    )
     result = model_runner._get_nans_in_logits(logits)
     assert result == {"req_0": 1, "req_1": 2}

-    logits = torch.tensor([
-        [1.0, 2.0, 3.0],
-        [4.0, float('nan'), float('nan')],
-    ],
-                          device=DEVICE)
+    logits = torch.tensor(
+        [
+            [1.0, 2.0, 3.0],
+            [4.0, float("nan"), float("nan")],
+        ],
+        device=DEVICE,
+    )
     result = model_runner._get_nans_in_logits(logits)
     assert result == {"req_0": 0, "req_1": 2}

     result = model_runner._get_nans_in_logits(logits=None)
     assert result == {"req_0": 0, "req_1": 0}

-    logits = torch.tensor([
-        [1.0, float('nan'), 3.0],
-    ], device=DEVICE)
+    logits = torch.tensor(
+        [
+            [1.0, float("nan"), 3.0],
+        ],
+        device=DEVICE,
+    )
     result = model_runner._get_nans_in_logits(logits)
-    assert result == {'req_0': 1, 'req_1': 0}
+    assert result == {"req_0": 1, "req_1": 0}

-    logits = torch.tensor([
-        [float('nan'), float('nan'), 2.0],
-        [1.0, 2.0, 3.0],
-        [float('nan'), 2.0, 3.0],
-    ],
-                          device=DEVICE)
+    logits = torch.tensor(
+        [
+            [float("nan"), float("nan"), 2.0],
+            [1.0, 2.0, 3.0],
+            [float("nan"), 2.0, 3.0],
+        ],
+        device=DEVICE,
+    )
     result = model_runner._get_nans_in_logits(logits)
-    assert result == {'req_0': 2, 'req_1': 0}
+    assert result == {"req_0": 2, "req_1": 0}


 def test_update_states_no_changes(model_runner, dist_init):
@@ -398,11 +419,13 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
 def test_kv_cache_stride_order(monkeypatch, model_runner):
     # This test checks if GPUModelRunner initializes correctly when an attention
     # backend enforces a non-default KV cache stride order.
-    n_heads = model_runner.model_config.get_num_kv_heads(
-        model_runner.parallel_config)
+    n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
     expected_kv_cache_shape = [
-        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
-        model_runner.model_config.get_head_size()
+        2,
+        NUM_BLOCKS,
+        BLOCK_SIZE,
+        n_heads,
+        model_runner.model_config.get_head_size(),
     ]
     # TODO mla test
     default_stride = tuple(range(5))
@@ -415,8 +438,9 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
     # Patch the attention backend class and re-trigger the KV cache creation
     for attn_group in model_runner._attn_group_iterator():
         attn_backend = attn_group.backend
-        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
-                            rnd_stride_order)
+        monkeypatch.setattr(
+            attn_backend, "get_kv_cache_stride_order", rnd_stride_order
+        )

     model_runner.attn_groups = []
     model_runner.kv_caches = []
@@ -448,14 +472,13 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
-        model_runner_2.get_model().state_dict())
-    model_runner_2.update_config(
-        {"load_config": {
-            "load_format": original_load_format
-        }})
+        model_runner_2.get_model().state_dict()
+    )
+    model_runner_2.update_config({"load_config": {"load_format": original_load_format}})
     model_runner_2.reload_weights()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
-        model_runner_2.get_model().state_dict())
+        model_runner_2.get_model().state_dict()
+    )


 def test_reload_weights_before_load_model(model_runner):
@@ -472,21 +495,19 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
-            layer_0:
-            Attention(
+            layer_0: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_0,
                 kv_sharing_target_layer_name=layer_1,
             ),
-            layer_1:
-            Attention(
+            layer_1: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_1,
-            )
+            ),
         }
         # suppress var not used error
         assert fwd_context is not None
@@ -500,22 +521,20 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
     error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
     with pytest.raises(ValueError, match=error_msg):
         fwd_context = {
-            layer_0:
-            Attention(
+            layer_0: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_0,
             ),
-            layer_1:
-            Attention(
+            layer_1: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_1,
                 # invalid layer: cross_attn.atn doesn't exist!
                 kv_sharing_target_layer_name=invalid_layer,
-            )
+            ),
         }
         # suppress var not used error
         assert fwd_context is not None
@@ -530,21 +549,19 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
-            layer_0:
-            Attention(
+            layer_0: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_0,
             ),
-            layer_1:
-            Attention(
+            layer_1: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_1,
                 kv_sharing_target_layer_name=layer_1,
-            )
+            ),
         }
         # suppress var not used error
         assert fwd_context is not None
@@ -557,20 +574,18 @@ def test_init_kv_cache_without_kv_sharing():
     vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
-            layer_0:
-            Attention(
+            layer_0: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_0,
             ),
-            layer_1:
-            Attention(
+            layer_1: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_1,
-            )
+            ),
         }
         # suppress var not used error
         assert fwd_context is not None
@@ -585,15 +600,15 @@ def test_init_kv_cache_without_kv_sharing():
     available_memory = 20 * GiB_bytes
     # page size for layer 0's kv_cache_spec is 32KB
     num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
-    kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
-                                           [available_memory])[0]
+    kv_cache_config = get_kv_cache_configs(
+        vllm_config, [kv_cache_spec], [available_memory]
+    )[0]
     assert kv_cache_config.num_blocks == num_expected_blocks
     assert len(kv_cache_config.kv_cache_tensors) == 2
     assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
     assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

-    max_context_len =\
-        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
     assert max_context_len == 1310720
@@ -601,8 +616,9 @@ def test_init_kv_cache_without_kv_sharing():
     # this will only allocate 2 block worth of memory (2 * 32kb)
     kv_cache_config.num_blocks = 1
     for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-        kv_cache_tensor.size = (
-            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
+        kv_cache_tensor.size = kv_cache_spec[
+            kv_cache_tensor.shared_by[0]
+        ].page_size_bytes

     runner.initialize_kv_cache(kv_cache_config)
@@ -625,21 +641,19 @@ def test_init_kv_cache_with_kv_sharing_valid():
     vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
-            layer_0:
-            Attention(
+            layer_0: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_0,
             ),
-            layer_1:
-            Attention(
+            layer_1: Attention(
                 num_heads=8,
                 head_size=64,
                 scale=1.0,
                 prefix=layer_1,
                 kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
-            )
+            ),
         }
         # suppress var not used error
         assert fwd_context is not None
@@ -657,24 +671,23 @@ def test_init_kv_cache_with_kv_sharing_valid():
     # with KV sharing, we can allocate (available_mem//page_size//1) blocks
     # which is twice as many as without KV sharing
    num_expected_blocks = 655360  # 20GB / 32KB
-    kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
-                                           [available_memory])[0]
+    kv_cache_config = get_kv_cache_configs(
+        vllm_config, [kv_cache_spec], [available_memory]
+    )[0]
     assert kv_cache_config.num_blocks == num_expected_blocks
     assert len(kv_cache_config.kv_cache_tensors) == 1
     # Each layer now has twice the available memory for KV cache
     # compared to no KV sharing
     assert kv_cache_config.kv_cache_tensors[0].size == available_memory

-    max_context_len =\
-        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
     assert max_context_len == 2 * 1310720
     # important: override tensor size to prevent large mem alloc during test
     # this will only allocate 1 block worth of memory (32kb)
     kv_cache_config.num_blocks = 1
-    kv_cache_config.kv_cache_tensors[0].size =\
-        kv_cache_spec[layer_0].page_size_bytes
+    kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes

     runner.initialize_kv_cache(kv_cache_config)

     kv_cache_config_after_init = runner.kv_cache_config
@@ -687,30 +700,30 @@ def test_init_kv_cache_with_kv_sharing_valid():
     # check layer 1 added to kv cache group's layer names
     assert len(kv_cache_config_after_init.kv_cache_groups) == 1
     assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
-    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
-        0] == layer_0
-    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
-        1] == layer_1
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1


 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
-    '''
+    """
     The GPU model runner creates different views into the
     KVCacheTensors for the attention and mamba layers
     (via _reshape_kv_cache_tensors function). This test verifies
     that the views are compatible: writing a mamba block
     will not corrupt an attention block and vice versa
-    '''
+    """

     current_platform.seed_everything(42)

-    update_environment_variables({
-        'RANK': "0",
-        'LOCAL_RANK': "0",
-        'WORLD_SIZE': "1",
-        'MASTER_ADDR': 'localhost',
-        'MASTER_PORT': '12345',
-    })
+    update_environment_variables(
+        {
+            "RANK": "0",
+            "LOCAL_RANK": "0",
+            "WORLD_SIZE": "1",
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "12345",
+        }
+    )
     init_distributed_environment()
     initialize_model_parallel(tensor_model_parallel_size=1)
     torch.set_default_dtype(torch.float16)
@@ -751,8 +764,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     fwd_context = {}
     for key in [layer_0, layer_1]:
         fwd_context[key] = Attention(
-            num_heads=model_config.get_num_attention_heads(
-                parallel_config),
+            num_heads=model_config.get_num_attention_heads(parallel_config),
             num_kv_heads=model_config.get_num_kv_heads(parallel_config),
             head_size=model_config.get_head_size(),
             scale=1.0,
@@ -760,13 +772,12 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         )
     for key in [layer_2, layer_3, layer_4, layer_5]:
         fwd_context[key] = MambaMixer2(
-            hidden_size = hf_config.hidden_size,
-            ssm_state_size = hf_config.mamba_d_state,
-            conv_kernel_size = hf_config.mamba_d_conv,
-            intermediate_size = hf_config.mamba_expand *\
-                                hf_config.hidden_size,
-            use_conv_bias = hf_config.mamba_conv_bias,
-            use_bias = hf_config.mamba_proj_bias,
+            hidden_size=hf_config.hidden_size,
+            ssm_state_size=hf_config.mamba_d_state,
+            conv_kernel_size=hf_config.mamba_d_conv,
+            intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
+            use_conv_bias=hf_config.mamba_conv_bias,
+            use_bias=hf_config.mamba_proj_bias,
             n_groups=hf_config.mamba_n_groups,
             num_heads=hf_config.mamba_n_heads,
             head_dim=hf_config.mamba_d_head,
@@ -781,15 +792,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     vllm_ctx = vllm_config.compilation_config.static_forward_context

     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

         runner = GPUModelRunner(vllm_config, DEVICE)

         kv_cache_spec = runner.get_kv_cache_spec()
         available_memory = 5 * GiB_bytes
-        kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
-                                               [available_memory])[0]
+        kv_cache_config = get_kv_cache_configs(
+            vllm_config, [kv_cache_spec], [available_memory]
+        )[0]
         runner.initialize_kv_cache(kv_cache_config)

         # random partition of blocks
@@ -798,7 +809,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         num_blocks = kv_cache_config.num_blocks
         ind = np.arange(num_blocks)
         np.random.shuffle(ind)
-        blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
+        blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
         attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
         conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
@@ -807,34 +818,40 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         # assert we are using FlashInfer
         assert attn_shape[0] == num_blocks

-        attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
-                                          device=DEVICE,
-                                          fill_value=3.33)
-        conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
-                                          device=DEVICE,
-                                          fill_value=6.66)
-        ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
-                                         device=DEVICE,
-                                         fill_value=9.99)
+        attn_blocks_constant = torch.full(
+            (len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
+        )
+        conv_blocks_constant = torch.full(
+            (len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
+        )
+        ssm_blocks_constant = torch.full(
+            (len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
+        )

         # fill all attention blocks with constant
         for layer in [layer_0, layer_1]:
-            vllm_ctx[layer].kv_cache[0][
-                blocks0, :] = attn_blocks_constant.detach().clone()
+            vllm_ctx[layer].kv_cache[0][blocks0, :] = (
+                attn_blocks_constant.detach().clone()
+            )

         # fill all mamba blocks with constant
         for layer in [layer_2, layer_3, layer_4, layer_5]:
-            vllm_ctx[layer].kv_cache[0][0][
-                blocks1, :] = conv_blocks_constant.detach().clone()
-            vllm_ctx[layer].kv_cache[0][1][
-                blocks1, :] = ssm_blocks_constant.detach().clone()
+            vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
+                conv_blocks_constant.detach().clone()
+            )
+            vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
+                ssm_blocks_constant.detach().clone()
+            )

         # verify attention and mamba contents are correct
         for layer in [layer_0, layer_1]:
-            assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
-                               attn_blocks_constant)
+            assert torch.equal(
+                vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
+            )
         for layer in [layer_2, layer_3, layer_4, layer_5]:
-            assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
-                               conv_blocks_constant)
-            assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
-                               ssm_blocks_constant)
+            assert torch.equal(
+                vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
+            )
+            assert torch.equal(
+                vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
+            )