[Model Runner V2] Support attention group (#35036)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Author: Woosuk Kwon
Date: 2026-02-21 16:42:53 -08:00 (committed by GitHub)
parent 74d90b1ce4
commit b71fbd06e2
5 changed files with 86 additions and 54 deletions

View File

@@ -7,17 +7,14 @@ import torch
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.v1.attention.backend import (
-    AttentionBackend,
-    AttentionMetadataBuilder,
-    CommonAttentionMetadata,
-)
+from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
+    UniformTypeKVCacheSpecs,
 )
-from vllm.v1.worker.utils import bind_kv_cache
+from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache


 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
@@ -35,29 +32,56 @@ def init_attn_backend(
     kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
 ):
     attn_backends: dict[str, type[AttentionBackend]] = {}
-    attn_metadata_builders: list[AttentionMetadataBuilder] = []
-    flashinfer_workspace: torch.Tensor | None = None
-    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+    attn_groups: list[list[AttentionGroup]] = []
+    attn_backend_workspace: torch.Tensor | None = None
+    for kv_cache_group_id, kv_cache_group_spec in enumerate(
+        kv_cache_config.kv_cache_groups
+    ):
         layer_names = kv_cache_group_spec.layer_names
-        any_layer_name = next(iter(layer_names))
         layer_type = cast(type[Any], AttentionLayerBase)
         attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
-        attn_backend = attn_layers[any_layer_name].get_attn_backend()
+        group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
+        group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
         for layer_name in layer_names:
+            attn_backend = attn_layers[layer_name].get_attn_backend()
             attn_backends[layer_name] = attn_backend
-        attn_metadata_builder = attn_backend.get_builder_cls()(
-            kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
-        )
-        attn_metadata_builders.append(attn_metadata_builder)  # type: ignore
-        if attn_backend.get_name() == "FLASHINFER":
-            if flashinfer_workspace is None:
-                flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
-            else:
-                attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
-    return attn_backends, attn_metadata_builders
+            layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
+            if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
+            key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
+            if key not in group_map:
+                group_map[key] = AttentionGroup(
+                    attn_backend,
+                    [layer_name],
+                    layer_kv_cache_spec,
+                    kv_cache_group_id,
+                )
+                group_order.append(key)
+            else:
+                group_map[key].layer_names.append(layer_name)
+        groups = [group_map[key] for key in group_order]
+        for group in groups:
+            group.create_metadata_builders(
+                vllm_config=vllm_config,
+                device=device,
+                kernel_block_size=None,
+                num_metadata_builders=1,
+            )
+            builder = group.get_metadata_builder(0)
+            if attn_backend_workspace is None:
+                if hasattr(builder, "_get_workspace_buffer"):
+                    attn_backend_workspace = builder._get_workspace_buffer()
+            else:
+                if hasattr(builder, "set_workspace_buffer"):
+                    builder.set_workspace_buffer(attn_backend_workspace)
+        attn_groups.append(groups)
+    return attn_backends, attn_groups


 def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
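
The grouping rule above is easier to see in isolation. Below is a minimal, self-contained sketch (AttentionGroupSketch and group_layers are hypothetical stand-ins, not the vLLM classes): within one KV cache group, layers are bucketed by (backend identity, per-layer KV cache spec), and buckets keep first-seen order. The workspace-buffer sharing across builders is omitted.

from dataclasses import dataclass, field


@dataclass
class AttentionGroupSketch:
    # Hypothetical stand-in for vllm.v1.worker.utils.AttentionGroup.
    backend_name: str
    layer_names: list[str] = field(default_factory=list)
    kv_cache_spec: str = ""
    kv_cache_group_id: int = 0


def group_layers(
    layer_specs: dict[str, tuple[str, str]], kv_cache_group_id: int
) -> list[AttentionGroupSketch]:
    """layer_specs maps layer name -> (backend name, KV cache spec key)."""
    group_map: dict[tuple[str, str], AttentionGroupSketch] = {}
    for layer_name, (backend, spec) in layer_specs.items():
        key = (backend, spec)
        if key not in group_map:
            # The first layer seen with this (backend, spec) pair opens a group.
            group_map[key] = AttentionGroupSketch(
                backend, [layer_name], spec, kv_cache_group_id
            )
        else:
            group_map[key].layer_names.append(layer_name)
    # dicts preserve insertion order, matching the explicit group_order
    # list kept in the real code.
    return list(group_map.values())


groups = group_layers(
    {
        "layers.0.attn": ("FLASH_ATTN", "full"),
        "layers.1.attn": ("FLASH_ATTN", "full"),
        "layers.2.attn": ("FLASHINFER", "full"),
    },
    kv_cache_group_id=0,
)
assert [g.layer_names for g in groups] == [
    ["layers.0.attn", "layers.1.attn"],
    ["layers.2.attn"],
]

Grouping by (backend, spec) rather than assuming one backend per KV cache group is what lets layers with different backends or per-layer specs coexist in the same group while each bucket still shares one metadata builder.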
@@ -144,7 +168,7 @@ def build_slot_mappings_by_layer(
 def build_attn_metadata(
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     num_reqs: int,
     num_tokens: int,
     query_start_loc_gpu: torch.Tensor,
@@ -162,8 +186,8 @@ def build_attn_metadata(
     dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]

     attn_metadata: dict[str, Any] = {}
-    kv_cache_groups = kv_cache_config.kv_cache_groups
-    for i, kv_cache_spec in enumerate(kv_cache_groups):
+    num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
+    for i in range(num_kv_cache_groups):
         block_table = block_tables[i]
         slot_mapping = slot_mappings[i]
@@ -181,10 +205,11 @@ def build_attn_metadata(
             dcp_local_seq_lens=dcp_local_seq_lens,
         )
-        attn_metadata_builder = attn_metadata_builders[i]
-        metadata = attn_metadata_builder.build(
-            common_prefix_len=0, common_attn_metadata=common_attn_metadata
-        )
-        for layer_name in kv_cache_spec.layer_names:
-            attn_metadata[layer_name] = metadata
+        for attn_group in attn_groups[i]:
+            attn_metadata_builder = attn_group.get_metadata_builder(0)
+            metadata = attn_metadata_builder.build(
+                common_prefix_len=0, common_attn_metadata=common_attn_metadata
+            )
+            for layer_name in attn_group.layer_names:
+                attn_metadata[layer_name] = metadata
     return attn_metadata
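
The consequence of the loop change above, sketched minimally (fan_out_metadata and build_for_group are hypothetical, not vLLM APIs): metadata is now built once per attention group within each KV cache group, and the same object is fanned out to every layer in that group.

from types import SimpleNamespace
from typing import Any, Callable


def fan_out_metadata(
    attn_groups: list[list[Any]],
    build_for_group: Callable[[Any], Any],
) -> dict[str, Any]:
    attn_metadata: dict[str, Any] = {}
    for groups in attn_groups:  # one inner list per KV cache group
        for group in groups:  # one entry per (backend, spec) bucket
            metadata = build_for_group(group)  # built once per group
            for layer_name in group.layer_names:
                attn_metadata[layer_name] = metadata  # shared by all layers
    return attn_metadata


g0 = SimpleNamespace(layer_names=["layers.0.attn", "layers.1.attn"])
g1 = SimpleNamespace(layer_names=["layers.2.attn"])
meta = fan_out_metadata([[g0, g1]], build_for_group=lambda g: object())
assert meta["layers.0.attn"] is meta["layers.1.attn"]
assert meta["layers.0.attn"] is not meta["layers.2.attn"]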

View File

@@ -13,7 +13,6 @@ from vllm.config.compilation import CUDAGraphMode
 from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -22,6 +21,7 @@ from vllm.v1.worker.gpu.attn_utils import (
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.v1.worker.utils import AttentionGroup


 class CudaGraphManager:
@@ -83,7 +83,7 @@ class CudaGraphManager:
         mrope_positions: torch.Tensor | None,
         inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
         has_lora: bool = False,
         uniform_decode: bool = False,
@@ -116,7 +116,7 @@
             num_tokens,
             input_buffers,
             block_tables,
-            attn_metadata_builders,
+            attn_groups,
             self.max_model_len,
             kv_cache_config,
             uniform_decode_query_len=(
@@ -232,7 +232,7 @@
         mrope_positions: torch.Tensor | None,
         inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
         has_lora: bool = False,
     ) -> None:
@@ -244,7 +244,7 @@
             mrope_positions=mrope_positions,
             inputs_embeds=inputs_embeds,
             block_tables=block_tables,
-            attn_metadata_builders=attn_metadata_builders,
+            attn_groups=attn_groups,
             kv_cache_config=kv_cache_config,
             has_lora=has_lora,
         )
@@ -286,6 +286,16 @@
             cudagraph_mode = self.cudagraph_mode.decode_mode()
         else:
             cudagraph_mode = self.cudagraph_mode.mixed_mode()
+
+        if (
+            cudagraph_mode == CUDAGraphMode.FULL
+            and cudagraph_size is not None
+            and cudagraph_size not in self.graphs
+        ):
+            # If graph wasn't captured yet, fall back to eager.
+            # This might happen when the dummy run is called before capture.
+            cudagraph_mode = CUDAGraphMode.NONE
+            cudagraph_size = None
         return cudagraph_mode, cudagraph_size

     def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
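
A compact sketch of the new dispatch guard (GraphMode and pick_dispatch are hypothetical stand-ins for CUDAGraphMode and the method above, with captured graphs reduced to a set of sizes): a FULL replay is only valid for a padded size that was actually captured; anything else falls back to eager, which can happen when a dummy run precedes capture.

from enum import Enum


class GraphMode(Enum):
    # Hypothetical stand-in for vllm.config.compilation.CUDAGraphMode.
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def pick_dispatch(
    mode: GraphMode,
    cudagraph_size: int | None,
    captured_sizes: set[int],
) -> tuple[GraphMode, int | None]:
    if (
        mode == GraphMode.FULL
        and cudagraph_size is not None
        and cudagraph_size not in captured_sizes
    ):
        # Graph not captured yet (e.g. a dummy run before capture): go eager.
        return GraphMode.NONE, None
    return mode, cudagraph_size


assert pick_dispatch(GraphMode.FULL, 64, {8, 16, 32}) == (GraphMode.NONE, None)
assert pick_dispatch(GraphMode.FULL, 16, {8, 16, 32}) == (GraphMode.FULL, 16)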
@@ -354,7 +364,7 @@ def prepare_inputs_to_capture(
     num_tokens: int,
     input_buffers: InputBuffers,
     block_tables: BlockTables,
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     max_model_len: int,
     kv_cache_config: KVCacheConfig,
     uniform_decode_query_len: int = 0,
@@ -386,7 +396,7 @@
     )
     attn_metadata = build_attn_metadata(
-        attn_metadata_builders=attn_metadata_builders,
+        attn_groups=attn_groups,
         num_reqs=num_reqs,
         num_tokens=num_tokens,
         query_start_loc_gpu=query_start_loc,

View File

@@ -283,7 +283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             cp_interleave=self.cp_interleave,
         )
-        self.attn_backends, self.attn_metadata_builders = init_attn_backend(
+        self.attn_backends, self.attn_groups = init_attn_backend(
             self.kv_cache_config, self.vllm_config, self.device
         )
         check_attention_cp_compatibility(self.vllm_config)
@@ -291,7 +291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # HACK(woosuk)
             self.speculator.set_attn(
                 self.kv_cache_config,
-                self.attn_metadata_builders,
+                self.attn_groups,
                 self.block_tables,
             )
@@ -305,9 +305,6 @@
         )
         self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
-        # Attention groups are not supported.
-        self.attn_groups = []  # type: ignore
-

     def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
         block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
         slot_mappings = self.block_tables.get_dummy_slot_mappings(
@@ -317,7 +314,7 @@
             slot_mappings, self.kv_cache_config
         )
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=input_batch.num_reqs,
             num_tokens=input_batch.num_tokens,
             query_start_loc_gpu=input_batch.query_start_loc,
@@ -477,7 +474,7 @@
             mrope_positions=mrope_positions,
             inputs_embeds=inputs_embeds,
             block_tables=self.block_tables,
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             kv_cache_config=self.kv_cache_config,
             has_lora=self.lora_config is not None,
         )
@@ -712,7 +709,7 @@
         # Layer name -> attention metadata.
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=num_reqs,
             num_tokens=num_tokens,
             query_start_loc_gpu=query_start_loc,

View File

@@ -7,7 +7,6 @@ import torch
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import (
@@ -17,6 +16,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
 )
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.v1.worker.utils import AttentionGroup


 class EagleCudaGraphManager:
@@ -60,7 +60,7 @@ class EagleCudaGraphManager:
         generate_fn: Callable,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
@@ -77,7 +77,7 @@
             num_tokens,
             input_buffers,
             block_tables,
-            attn_metadata_builders,
+            attn_groups,
             self.max_model_len,
             kv_cache_config,
             uniform_decode_query_len=1,
@@ -150,7 +150,7 @@ class EagleCudaGraphManager:
         generate_fn: Callable,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         if self.cudagraph_mode == CUDAGraphMode.NONE:
@@ -165,7 +165,7 @@
             generate_fn=generate_fn,
             input_buffers=input_buffers,
             block_tables=block_tables,
-            attn_metadata_builders=attn_metadata_builders,
+            attn_groups=attn_groups,
             kv_cache_config=kv_cache_config,
         )

View File

@@ -10,7 +10,6 @@ from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -21,6 +20,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
 from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
+from vllm.v1.worker.utils import AttentionGroup

 logger = init_logger(__name__)
@@ -78,11 +78,11 @@ class EagleSpeculator:
     def set_attn(
         self,
         kv_cache_config: KVCacheConfig,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         block_tables: BlockTables,
     ) -> None:
         self.kv_cache_config = kv_cache_config
-        self.attn_metadata_builders = attn_metadata_builders
+        self.attn_groups = attn_groups
         self.block_tables = block_tables

     @torch.inference_mode()
@@ -174,7 +174,7 @@
             self.generate_draft,
             self.input_buffers,
             self.block_tables,
-            self.attn_metadata_builders,
+            self.attn_groups,
             self.kv_cache_config,
         )
@@ -298,7 +298,7 @@
         # FIXME(woosuk): This is UNSAFE!!
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=num_reqs,
             num_tokens=num_reqs,
             query_start_loc_gpu=query_start_loc,