[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Matthew Bonanni authored on 2025-12-05 12:48:43 -05:00; committed by GitHub
parent dff0a2b394
commit 66e674cdd5
22 changed files with 367 additions and 325 deletions
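For context: this change replaces the VLLM_ATTENTION_BACKEND environment variable with an explicit attention config that is passed into VllmConfig (and, per the title, exposed as a CLI argument). A minimal sketch of the new in-code pattern, using only names that appear in this diff; the corresponding CLI flag (presumably something like --attention-backend) is not shown in this file and is an assumption:

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig

    # Select the attention backend explicitly instead of setting the
    # VLLM_ATTENTION_BACKEND environment variable.
    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)

    # The config is then passed to VllmConfig alongside the other sub-configs,
    # e.g. VllmConfig(..., attention_config=attention_config), as the updated
    # test below does.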


@@ -6,8 +6,10 @@ import pytest
 import torch
 from vllm.attention.backends.abstract import MultipleOf
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import Attention
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     ModelConfig,
     ParallelConfig,
@@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
     current_platform.is_rocm(),
     reason="Attention backend FLASHINFER is not supported on ROCm.",
 )
-def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
+def test_hybrid_attention_mamba_tensor_shapes():
     """
     The GPU model runner creates different views into the
     KVCacheTensors for the attention and mamba layers
@@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
     vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
+        attention_config=attention_config,
     )

     layer_0 = "model.layers.0.self_attn.attn"
@@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"

-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+    with set_current_vllm_config(vllm_config):
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
@@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         )
         # suppress var not used error
         assert fwd_context is not None
     vllm_ctx = vllm_config.compilation_config.static_forward_context

-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

     runner = GPUModelRunner(vllm_config, DEVICE)
     kv_cache_spec = runner.get_kv_cache_spec()
@@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     )[0]
     runner.initialize_kv_cache(kv_cache_config)

     # random partition of blocks
     # blocks0 will be assigned to attention layers
     # blocks1 will be assigned to mamba layers
     num_blocks = kv_cache_config.num_blocks
     ind = np.arange(num_blocks)
     np.random.shuffle(ind)
     blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]

     attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
     conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
     ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

     # assert we are using FlashInfer
     assert attn_shape[0] % num_blocks == 0
     block_split_ratio = attn_shape[0] // num_blocks

     # use small blocks for testing to avoid memory issues
     test_block_size = min(2, len(blocks0), len(blocks1))

     # use non-overlapping blocks to avoid data contamination
     # Split kernel blocks: first half for attention, second half for mamba
     mid_point = num_blocks // 2

     # attention uses kernel blocks from first half (mapped to logical blocks)
     kv_blocks_for_attention = np.array([0, 1])[:test_block_size]

     # mamba uses kernel blocks from second half
     kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]

     # create small constant tensors for testing with corrected shapes
     # attention: [block_size, ...] starting from dimension 2
     attn_constant_shape = attn_shape[2:]
     conv_constant_shape = conv_shape[1:]
     ssm_constant_shape = ssm_shape[1:]

     attn_blocks_constant = torch.full(
         (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
     )
     conv_blocks_constant = torch.full(
         (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
     )
     ssm_blocks_constant = torch.full(
         (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
     )

     # Fill attention blocks with constants using kv block indices
     kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
     for layer in [layer_0, layer_1]:
         # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
             vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]

     # fill mamba blocks with constants using kernel block indices
     for layer in [layer_2, layer_3, layer_4, layer_5]:
         # mamba: kv_cache[0][component][kernel_block_idx, ...]
         for i, kv_block in enumerate(kv_blocks_for_mamba):
             vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
             vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]

     # verify attention and mamba contents are correct
     for layer in [layer_0, layer_1]:
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
             actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
             expected = attn_blocks_constant[i]
             # Check K and V separately
             assert torch.equal(actual_kv[0], expected)
             assert torch.equal(actual_kv[1], expected)

     for layer in [layer_2, layer_3, layer_4, layer_5]:
         for i, kv_block in enumerate(kv_blocks_for_mamba):
             actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
             actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
             expected_conv = conv_blocks_constant[i]
             expected_ssm = ssm_blocks_constant[i]
             assert torch.equal(actual_conv, expected_conv)
             assert torch.equal(actual_ssm, expected_ssm)

def test_hybrid_block_table_initialization():