[core] clean up cudagraph batchsize padding logic (#10996)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -5,7 +5,7 @@ Run `pytest tests/models/test_mamba.py`.
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that the mamba cache is padded to the CG
     # captured batch size. If it's not, a torch RuntimeError will be raised
     # because tensor dimensions aren't compatible.
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+    vllm_config = EngineArgs(model=model).create_engine_config()
+    while len(example_prompts) == vllm_config.pad_for_cudagraph(
             len(example_prompts)):
         example_prompts.append(example_prompts[0])
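For context, a minimal sketch of the padding rule that pad_for_cudagraph centralizes: vLLM captures CUDA graphs for a fixed set of batch sizes and rounds a runtime batch size up to the nearest captured one. The capture schedule below (1, 2, 4, then multiples of 8) mirrors vLLM's historical defaults, and the helper name pad_batch_size is a hypothetical stand-in for the real method, not the library's API:

# Sketch of cudagraph batch-size padding, assuming a capture schedule of
# 1, 2, 4, then multiples of 8 up to 256. `pad_batch_size` is a
# hypothetical stand-in for vllm_config.pad_for_cudagraph.
import bisect

_CAPTURE_SIZES = [1, 2, 4] + [8 * i for i in range(1, 33)]  # ends at 256

def pad_batch_size(batch_size: int) -> int:
    """Round batch_size up to the nearest captured graph batch size."""
    idx = bisect.bisect_left(_CAPTURE_SIZES, batch_size)
    if idx == len(_CAPTURE_SIZES):
        raise ValueError(f"batch size {batch_size} exceeds the largest "
                         f"captured size {_CAPTURE_SIZES[-1]}")
    return _CAPTURE_SIZES[idx]

assert pad_batch_size(3) == 4   # padded up to the next captured size
assert pad_batch_size(8) == 8   # already a captured size, no padding
assert pad_batch_size(9) == 16

Under this rule, the test's while loop appends prompts until the batch size is no longer equal to its own padded value, which guarantees the engine must actually pad the mamba cache rather than hitting a captured batch size exactly.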