[torch.compile] remove compilation_context and simplify code (#10838)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-02 22:19:02 -08:00
committed by GitHub
parent 21fe7b481a
commit dc5ce861bf
14 changed files with 128 additions and 143 deletions

View File

@@ -5,8 +5,8 @@ Run `pytest tests/models/test_mamba.py`.
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size
from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test verifies that the mamba cache is padded to the CG-captured
# batch size. If it is not, a torch RuntimeError is raised because the
# tensor dimensions are incompatible.
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])
try: