[core] clean up cudagraph batchsize padding logic (#10996)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-12 22:57:50 -08:00
committed by GitHub
parent 34f1a806d5
commit be39e3cd18
11 changed files with 150 additions and 104 deletions

View File

@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])