[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -6,8 +6,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import MultipleOf
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
@@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
current_platform.is_rocm(),
|
||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
||||
)
|
||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
def test_hybrid_attention_mamba_tensor_shapes():
|
||||
"""
|
||||
The GPU model runner creates different views into the
|
||||
KVCacheTensors for the attention and mamba layers
|
||||
@@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
cache_dtype="auto",
|
||||
)
|
||||
parallel_config = ParallelConfig()
|
||||
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
parallel_config=parallel_config,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
@@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
layer_4 = "model.layers.4.mixer"
|
||||
layer_5 = "model.layers.5.mixer"
|
||||
|
||||
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
fwd_context = {}
|
||||
for key in [layer_0, layer_1]:
|
||||
@@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
)
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
@@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
)[0]
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# random partition of blocks
|
||||
# blocks0 will be assigned to attention layers
|
||||
# blocks1 will be assigned to mamba layers
|
||||
num_blocks = kv_cache_config.num_blocks
|
||||
ind = np.arange(num_blocks)
|
||||
np.random.shuffle(ind)
|
||||
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
|
||||
# random partition of blocks
|
||||
# blocks0 will be assigned to attention layers
|
||||
# blocks1 will be assigned to mamba layers
|
||||
num_blocks = kv_cache_config.num_blocks
|
||||
ind = np.arange(num_blocks)
|
||||
np.random.shuffle(ind)
|
||||
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
|
||||
|
||||
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
|
||||
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
|
||||
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
|
||||
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
|
||||
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
|
||||
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
|
||||
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] % num_blocks == 0
|
||||
block_split_ratio = attn_shape[0] // num_blocks
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] % num_blocks == 0
|
||||
block_split_ratio = attn_shape[0] // num_blocks
|
||||
|
||||
# use small blocks for testing to avoid memory issues
|
||||
test_block_size = min(2, len(blocks0), len(blocks1))
|
||||
# use small blocks for testing to avoid memory issues
|
||||
test_block_size = min(2, len(blocks0), len(blocks1))
|
||||
|
||||
# use non-overlapping blocks to avoid data contamination
|
||||
# Split kernel blocks: first half for attention, second half for mamba
|
||||
mid_point = num_blocks // 2
|
||||
# use non-overlapping blocks to avoid data contamination
|
||||
# Split kernel blocks: first half for attention, second half for mamba
|
||||
mid_point = num_blocks // 2
|
||||
|
||||
# attention uses kernel blocks from first half (mapped to logical blocks)
|
||||
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
|
||||
# attention uses kernel blocks from first half (mapped to logical blocks)
|
||||
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
|
||||
|
||||
# mamba uses kernel blocks from second half
|
||||
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
|
||||
# mamba uses kernel blocks from second half
|
||||
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
|
||||
|
||||
# create small constant tensors for testing with corrected shapes
|
||||
# attention: [block_size, ...] starting from dimension 2
|
||||
attn_constant_shape = attn_shape[2:]
|
||||
conv_constant_shape = conv_shape[1:]
|
||||
ssm_constant_shape = ssm_shape[1:]
|
||||
# create small constant tensors for testing with corrected shapes
|
||||
# attention: [block_size, ...] starting from dimension 2
|
||||
attn_constant_shape = attn_shape[2:]
|
||||
conv_constant_shape = conv_shape[1:]
|
||||
ssm_constant_shape = ssm_shape[1:]
|
||||
|
||||
attn_blocks_constant = torch.full(
|
||||
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
|
||||
)
|
||||
conv_blocks_constant = torch.full(
|
||||
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
|
||||
)
|
||||
ssm_blocks_constant = torch.full(
|
||||
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
|
||||
)
|
||||
attn_blocks_constant = torch.full(
|
||||
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
|
||||
)
|
||||
conv_blocks_constant = torch.full(
|
||||
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
|
||||
)
|
||||
ssm_blocks_constant = torch.full(
|
||||
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
|
||||
)
|
||||
|
||||
# Fill attention blocks with constants using kv block indices
|
||||
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
|
||||
# Fill attention blocks with constants using kv block indices
|
||||
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
|
||||
|
||||
for layer in [layer_0, layer_1]:
|
||||
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
|
||||
for layer in [layer_0, layer_1]:
|
||||
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
|
||||
|
||||
# fill mamba blocks with constants using kernel block indices
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
# mamba: kv_cache[0][component][kernel_block_idx, ...]
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
|
||||
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
|
||||
# fill mamba blocks with constants using kernel block indices
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
# mamba: kv_cache[0][component][kernel_block_idx, ...]
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
|
||||
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
|
||||
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
|
||||
expected = attn_blocks_constant[i]
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
|
||||
expected = attn_blocks_constant[i]
|
||||
|
||||
# Check K and V separately
|
||||
assert torch.equal(actual_kv[0], expected)
|
||||
assert torch.equal(actual_kv[1], expected)
|
||||
# Check K and V separately
|
||||
assert torch.equal(actual_kv[0], expected)
|
||||
assert torch.equal(actual_kv[1], expected)
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
|
||||
def test_hybrid_block_table_initialization():
|
||||
|
||||
Reference in New Issue
Block a user