[Model] Add LFM2 architecture (#22845)
Signed-off-by: Paul Pak <paulpak58@gmail.com>
This commit is contained in:
@@ -4,6 +4,8 @@ from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
|
||||
from vllm.v1.attention.backends.short_conv_attn import (
|
||||
ShortConvAttentionBackend)
|
||||
|
||||
|
||||
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
||||
@@ -13,6 +15,8 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
||||
return Mamba2AttentionBackend
|
||||
if mamba_type == "linear_attention":
|
||||
return LinearAttentionBackend
|
||||
if mamba_type == "short_conv":
|
||||
return ShortConvAttentionBackend
|
||||
|
||||
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
|
||||
"supported yet.")
|
||||
|
||||
81
vllm/v1/attention/backends/short_conv_attn.py
Normal file
81
vllm/v1/attention/backends/short_conv_attn.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
|
||||
return ShortConvAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortConvAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
has_initial_states: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
# For causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[ShortConvAttentionMetadata]):
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> ShortConvAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
decode_threshold=1))
|
||||
has_initial_states = None
|
||||
if num_prefills > 0:
|
||||
#[batch,]
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
attn_metadata = ShortConvAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
has_initial_states=has_initial_states,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
)
|
||||
return attn_metadata
|
||||
Reference in New Issue
Block a user