[torch.compile] remove compilation_context and simplify code (#10838)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
|
||||
import pytest
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.worker.model_runner import _get_graph_batch_size
|
||||
|
||||
from ...utils import check_outputs_equal
|
||||
|
||||
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
|
||||
# This test is for verifying that mamba cache is padded to CG captured
|
||||
# batch size. If it's not, a torch RuntimeError will be raised because
|
||||
# tensor dimensions aren't compatible
|
||||
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
|
||||
while len(example_prompts) == VllmConfig.get_graph_batch_size(
|
||||
len(example_prompts)):
|
||||
example_prompts.append(example_prompts[0])
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user