[Model Runner V2] Support attention group (#35036)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -7,17 +7,14 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
|
||||
|
||||
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
@@ -35,29 +32,56 @@ def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
|
||||
):
|
||||
attn_backends: dict[str, type[AttentionBackend]] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
flashinfer_workspace: torch.Tensor | None = None
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
attn_groups: list[list[AttentionGroup]] = []
|
||||
attn_backend_workspace: torch.Tensor | None = None
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
kv_cache_config.kv_cache_groups
|
||||
):
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
any_layer_name = next(iter(layer_names))
|
||||
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
|
||||
attn_backend = attn_layers[any_layer_name].get_attn_backend()
|
||||
|
||||
group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
|
||||
group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
|
||||
|
||||
for layer_name in layer_names:
|
||||
attn_backend = attn_layers[layer_name].get_attn_backend()
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||
kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||
layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
|
||||
|
||||
if attn_backend.get_name() == "FLASHINFER":
|
||||
if flashinfer_workspace is None:
|
||||
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
|
||||
key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
|
||||
if key not in group_map:
|
||||
group_map[key] = AttentionGroup(
|
||||
attn_backend,
|
||||
[layer_name],
|
||||
layer_kv_cache_spec,
|
||||
kv_cache_group_id,
|
||||
)
|
||||
group_order.append(key)
|
||||
else:
|
||||
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
|
||||
return attn_backends, attn_metadata_builders
|
||||
group_map[key].layer_names.append(layer_name)
|
||||
|
||||
groups = [group_map[key] for key in group_order]
|
||||
for group in groups:
|
||||
group.create_metadata_builders(
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
kernel_block_size=None,
|
||||
num_metadata_builders=1,
|
||||
)
|
||||
builder = group.get_metadata_builder(0)
|
||||
if attn_backend_workspace is None:
|
||||
if hasattr(builder, "_get_workspace_buffer"):
|
||||
attn_backend_workspace = builder._get_workspace_buffer()
|
||||
else:
|
||||
if hasattr(builder, "set_workspace_buffer"):
|
||||
builder.set_workspace_buffer(attn_backend_workspace)
|
||||
attn_groups.append(groups)
|
||||
return attn_backends, attn_groups
|
||||
|
||||
|
||||
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
|
||||
@@ -144,7 +168,7 @@ def build_slot_mappings_by_layer(
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
@@ -162,8 +186,8 @@ def build_attn_metadata(
|
||||
dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
for i in range(num_kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
@@ -181,10 +205,11 @@ def build_attn_metadata(
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
for attn_group in attn_groups[i]:
|
||||
attn_metadata_builder = attn_group.get_metadata_builder(0)
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
@@ -22,6 +21,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
@@ -83,7 +83,7 @@ class CudaGraphManager:
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
@@ -116,7 +116,7 @@ class CudaGraphManager:
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
@@ -232,7 +232,7 @@ class CudaGraphManager:
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
@@ -244,7 +244,7 @@ class CudaGraphManager:
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
@@ -286,6 +286,16 @@ class CudaGraphManager:
|
||||
cudagraph_mode = self.cudagraph_mode.decode_mode()
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode.mixed_mode()
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
|
||||
@@ -354,7 +364,7 @@ def prepare_inputs_to_capture(
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
@@ -386,7 +396,7 @@ def prepare_inputs_to_capture(
|
||||
)
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
|
||||
@@ -283,7 +283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
cp_interleave=self.cp_interleave,
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
self.attn_backends, self.attn_groups = init_attn_backend(
|
||||
self.kv_cache_config, self.vllm_config, self.device
|
||||
)
|
||||
check_attention_cp_compatibility(self.vllm_config)
|
||||
@@ -291,7 +291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
self.kv_cache_config,
|
||||
self.attn_metadata_builders,
|
||||
self.attn_groups,
|
||||
self.block_tables,
|
||||
)
|
||||
|
||||
@@ -305,9 +305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
# Attention groups are not supported.
|
||||
self.attn_groups = [] # type: ignore
|
||||
|
||||
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
@@ -317,7 +314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
@@ -477,7 +474,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=self.block_tables,
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
attn_groups=self.attn_groups,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
@@ -712,7 +709,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
@@ -17,6 +16,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class EagleCudaGraphManager:
|
||||
@@ -60,7 +60,7 @@ class EagleCudaGraphManager:
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
@@ -77,7 +77,7 @@ class EagleCudaGraphManager:
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=1,
|
||||
@@ -150,7 +150,7 @@ class EagleCudaGraphManager:
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
if self.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
@@ -165,7 +165,7 @@ class EagleCudaGraphManager:
|
||||
generate_fn=generate_fn,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
@@ -21,6 +20,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -78,11 +78,11 @@ class EagleSpeculator:
|
||||
def set_attn(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
block_tables: BlockTables,
|
||||
) -> None:
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.attn_metadata_builders = attn_metadata_builders
|
||||
self.attn_groups = attn_groups
|
||||
self.block_tables = block_tables
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -174,7 +174,7 @@ class EagleSpeculator:
|
||||
self.generate_draft,
|
||||
self.input_buffers,
|
||||
self.block_tables,
|
||||
self.attn_metadata_builders,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
)
|
||||
|
||||
@@ -298,7 +298,7 @@ class EagleSpeculator:
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
|
||||
Reference in New Issue
Block a user