Signed-off-by: haosdent <haosdent@gmail.com> Co-authored-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -1199,3 +1199,123 @@ def test_is_uniform_decode() -> None:
|
||||
num_reqs=15,
|
||||
force_uniform_decode=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_cudagraph_sizes_capped_for_mamba_cache():
    """Verify cudagraph capture sizes never exceed ``num_blocks``.

    A hybrid model (attention layers plus Mamba mixer layers) is registered,
    the cudagraph capture limit is deliberately set above the number of KV
    cache blocks, and the KV cache is then initialized. Afterwards both the
    maximum capture size and every individual capture size must be at most
    ``num_blocks``, and the last capture size must equal the maximum.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    set_random_seed(42)

    # Single-process "distributed" environment.
    dist_env = {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12345",
    }
    update_environment_variables(dist_env)
    from tests.utils import ensure_current_vllm_config

    with ensure_current_vllm_config():
        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=1)
    torch.set_default_dtype(torch.float16)

    model_config = ModelConfig(
        model="ibm-granite/granite-4.0-tiny-preview",
        dtype="float16",
    )
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    cache_config = CacheConfig(
        block_size=BLOCK_SIZE,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    parallel_config = ParallelConfig()
    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
        attention_config=attention_config,
    )

    # Layer prefixes: two attention layers followed by four Mamba mixers.
    attn_prefixes = (
        "model.layers.0.self_attn.attn",
        "model.layers.1.self_attn.attn",
    )
    mamba_prefixes = (
        "model.layers.2.mixer",
        "model.layers.3.mixer",
        "model.layers.4.mixer",
        "model.layers.5.mixer",
    )

    with set_current_vllm_config(vllm_config):
        hf_config = vllm_config.model_config.hf_config
        layers_by_prefix = {}

        for layer_name in attn_prefixes:
            attn_layer = Attention(
                num_heads=model_config.get_num_attention_heads(parallel_config),
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                scale=1.0,
                prefix=layer_name,
            )
            layers_by_prefix[layer_name] = attn_layer

        for layer_name in mamba_prefixes:
            mixer_layer = MambaMixer2(
                hidden_size=hf_config.hidden_size,
                ssm_state_size=hf_config.mamba_d_state,
                conv_kernel_size=hf_config.mamba_d_conv,
                intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
                use_conv_bias=hf_config.mamba_conv_bias,
                use_bias=hf_config.mamba_proj_bias,
                n_groups=hf_config.mamba_n_groups,
                num_heads=hf_config.mamba_n_heads,
                head_dim=hf_config.mamba_d_head,
                rms_norm_eps=hf_config.rms_norm_eps,
                activation=hf_config.hidden_act,
                cache_config=cache_config,
                model_config=model_config,
                prefix=layer_name,
            )
            layers_by_prefix[layer_name] = mixer_layer

        # The mapping is only built to hold the layers; this check references
        # it so the variable is not flagged as unused.
        assert layers_by_prefix is not None

    runner = GPUModelRunner(vllm_config, DEVICE)
    kv_cache_spec = runner.get_kv_cache_spec()

    available_memory = 5 * GiB_bytes
    kv_cache_configs = get_kv_cache_configs(
        vllm_config, [kv_cache_spec], [available_memory]
    )
    kv_cache_config = kv_cache_configs[0]
    num_blocks = kv_cache_config.num_blocks

    # Push the capture limit above num_blocks so that the Mamba capping
    # logic has something to cap.
    oversized_max = num_blocks + 100
    candidate_sizes = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
    comp_cfg = vllm_config.compilation_config
    comp_cfg.max_cudagraph_capture_size = oversized_max
    comp_cfg.cudagraph_capture_sizes = [
        size for size in candidate_sizes if size <= oversized_max
    ]

    runner.initialize_kv_cache(kv_cache_config)

    # Once the KV cache is initialized, nothing may exceed num_blocks.
    assert comp_cfg.max_cudagraph_capture_size <= num_blocks
    assert all(size <= num_blocks for size in comp_cfg.cudagraph_capture_sizes)
    # Invariant: the last capture size equals the configured maximum
    # (only checkable when at least one size survived the cap).
    if comp_cfg.cudagraph_capture_sizes:
        last_size = comp_cfg.cudagraph_capture_sizes[-1]
        assert last_size == comp_cfg.max_cudagraph_capture_size
|
||||
|
||||
Reference in New Issue
Block a user