[V1] - Split Prefill and Decode for Mamba1 models (#22653)

Signed-off-by: amirk <amirk@ai21.com>
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
amirai21
2025-08-15 11:59:52 +03:00
committed by GitHub
parent 5406ebf5c9
commit fe91ce9591
3 changed files with 251 additions and 93 deletions

View File

@@ -2,14 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
query_start_loc: torch.Tensor
context_lens_tensor: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states: torch.Tensor
has_initial_states: Optional[torch.Tensor]
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
class Mamba1AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba1AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
@@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
query_start_loc.device)
has_initial_states = (context_lens_tensor > 0)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
has_initial_states = None
if num_prefills > 0:
has_initial_states = context_lens_tensor > 0
return Mamba1AttentionMetadata(
query_start_loc=query_start_loc,
context_lens_tensor=context_lens_tensor,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
)