[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -1,14 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
def initialize_kv_cache(runner: GPUModelRunner):
    """
    Minimal stand-in for GPUModelRunner.initialize_kv_cache().

    Only the step needed here is performed: an attention metadata builder
    is constructed and stored on ``runner.attn_metadata_builder``.
    """
    # Fixed full-attention KV cache spec used for the builder.
    spec = FullAttentionSpec(
        block_size=16,
        num_kv_heads=1,
        head_size=64,
        dtype=torch.float16,
        use_mla=False,
    )

    # Resolve the builder class from the runner's attention backend.
    builder_cls = runner.attn_backend.get_builder_cls()

    # The builder receives a weak proxy, so it does not hold a strong
    # reference back to the runner that owns it.
    runner.attn_metadata_builder = builder_cls(
        weakref.proxy(runner),
        spec,
        runner.input_batch.block_table,
    )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
scheduler_config = SchedulerConfig(
|
||||
@@ -38,7 +55,9 @@ def model_runner():
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
return GPUModelRunner(vllm_config, device)
|
||||
runner = GPUModelRunner(vllm_config, device)
|
||||
initialize_kv_cache(runner)
|
||||
return runner
|
||||
|
||||
|
||||
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
|
||||
Reference in New Issue
Block a user