[core] clean up cudagraph batchsize padding logic (#10996)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-12 22:57:50 -08:00
committed by GitHub
parent 34f1a806d5
commit be39e3cd18
11 changed files with 150 additions and 104 deletions

View File

@@ -1,7 +1,7 @@
import pytest
from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == VllmConfig.get_graph_batch_size(
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])