[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-19 04:34:15 +08:00
committed by GitHub
parent ffacb222cb
commit a89209b78d
9 changed files with 582 additions and 120 deletions

View File

@@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass
from math import prod
from typing import Optional
import torch
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass
class MambaSpec(KVCacheSpec):
    """KV-cache spec for Mamba (state-space) layers.

    A Mamba layer keeps a fixed set of per-request state tensors rather
    than a growing token cache. ``shapes`` lists the shape of each state
    tensor for one request and ``dtype`` is their element type.
    """

    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype

    def __post_init__(self):
        # Total number of scalar elements across all state tensors
        # belonging to a single request.
        total = 0
        for shape in self.shapes:
            total += prod(shape)
        self.num_elements = total

    @property
    def type_id(self) -> str:
        # Identifier used to group layers whose caches are interchangeable.
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        # A single page holds the complete state for one request.
        return get_dtype_size(self.dtype) * self.num_elements

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes
@dataclass
class KVCacheTensor:
"""