[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    # The outer list is for v0 PP virtual engine. Though this code path
    # only runs for v1, we have to do this to unify with the interface
    # of Attention + v0 PP.
    kv_cache: list[Iterable[torch.Tensor]]
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
@@ -15,7 +15,6 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn

from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
import torch
import torch.distributed

from vllm.model_executor.models.minimax_cache import MinimaxCacheParams


class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"
@@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
            self.tp_heads:(self.tp_rank + 1) *
            self.tp_heads].contiguous()

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    @staticmethod
    def weight_direct_load(param: torch.Tensor,
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            # prefills are packed at end of batch in V1
            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
            offset = attn_metadata.num_decode_tokens
            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
            slot_id = state_indices_tensor[offset + _prefill_idx]
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
            hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                               state_indices_tensor,
                                               attn_metadata)
            if envs.VLLM_USE_V1:
                hidden.insert(0, hidden_decode)
            else:
                hidden.append(hidden_decode)
            hidden.insert(0, hidden_decode)

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
        if not envs.VLLM_USE_V1:
            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            slot_id = state_indices_tensor[num_prefills:]
        else:
            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: MinimaxCacheParams) -> None:
        if not envs.VLLM_USE_V1:
            self._forward(hidden_states, output, positions, kv_caches)
        else:
            torch.ops.vllm.linear_attention(
                hidden_states,
                output,
                positions,
                self.prefix,
            )
                positions: torch.Tensor) -> None:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: Optional[MinimaxCacheParams]) -> None:
                 positions: torch.Tensor) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1 and attn_metadata is not None:
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor
        if attn_metadata is not None:
            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
            state_indices_tensor = attn_metadata.state_indices_tensor

            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            if num_prefills > 0:
                num_decode_tokens = getattr(attn_metadata,
                                            "num_decode_tokens", 0)
                for prefill_idx in range(num_prefills):
                    q_start = attn_metadata.query_start_loc[
                        num_decode_tokens + prefill_idx]
                    q_end = attn_metadata.query_start_loc[num_decode_tokens
                                                          + prefill_idx +
                                                          1]
                    query_len = q_end - q_start
                    context_len = attn_metadata.seq_lens[
                        num_decode_tokens + prefill_idx] - query_len
                    if context_len == 0:
                        block_to_clear = state_indices_tensor[
                            num_decode_tokens + prefill_idx]
                        kv_cache[block_to_clear, ...] = 0
        else:
            assert kv_caches is not None
            kv_cache = kv_caches.minimax_cache
            state_indices_tensor = kv_caches.state_indices_tensor
        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        if num_prefills > 0:
            num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
                                        0)
            for prefill_idx in range(num_prefills):
                q_start = attn_metadata.query_start_loc[num_decode_tokens +
                                                        prefill_idx]
                q_end = attn_metadata.query_start_loc[num_decode_tokens +
                                                      prefill_idx + 1]
                query_len = q_end - q_start
                context_len = attn_metadata.seq_lens[
                    num_decode_tokens + prefill_idx] - query_len
                if context_len == 0:
                    block_to_clear = state_indices_tensor[num_decode_tokens
                                                          + prefill_idx]
                    kv_cache[block_to_clear, ...] = 0

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if attn_metadata is None:
@@ -410,8 +384,7 @@ def linear_attention(
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states=hidden_states,
                  output=output,
                  positions=positions,
                  kv_caches=None)
                  positions=positions)


def linear_attention_fake(
@@ -1,177 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import (
    PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import (
    Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)


@dataclass
class Mamba2Metadata:
    prep_initial_states: bool
    chunk_size: int

    has_initial_states_p: torch.Tensor
    seq_idx_p: torch.Tensor
    chunk_indices_p: torch.Tensor
    chunk_offsets_p: torch.Tensor
    """
    With continuous batching layout of `x` in vLLM, to enable a Triton program
    to handle a request in parallel, two supporting tensors are used
    (batch_ptr, token_chunk_offset_ptr)
    BLOCK_M = the # tokens to be handled by a Triton program
    (can be customized for different hardware)

    nums_dict:
        tracks the data associated with a given value of BLOCK_M
        BLOCK_M = #tokens handled by a Triton program
        cu_seqlen: total tokens per batch
        (used as flag to update other data at each new input)
        batch_ptr: tracks batch-id handled by the Triton program
        token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
    (Triton implementation of causal_conv1d handles parallelism in 3-axes
    - feature-axis
    - batch-axis
    - sequence-axis)
    """
    nums_dict: Optional[dict] = None
    cu_seqlen: Optional[int] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None


def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
    """Returns the appropriate metadata classes for the current platform."""
    if current_platform.is_rocm():
        from vllm.v1.attention.backends.rocm_aiter_fa import (
            AiterFlashAttentionMetadata)
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionMetadata)
        return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
                PlaceholderAttentionMetadata)
    if current_platform.is_cuda():
        from vllm.v1.attention.backends.flash_attn import (
            FlashAttentionMetadata)
        from vllm.v1.attention.backends.xformers import (
            XFormersAttentionMetadata)
        return (FlashAttentionMetadata, XFormersAttentionMetadata,
                PlaceholderAttentionMetadata)
    raise ValueError(
        f"Unsupported platform for Mamba2: {current_platform.device_type}")


def prepare_mamba2_metadata(
    chunk_size: int,
    attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:

    # compute number of prefill and decode requests
    # NOTE: in V0 we assume prefills are before decodes
    num_prefills = attn_metadata.num_prefills
    num_prefill_tokens = attn_metadata.num_prefill_tokens

    seq_idx_p = None
    chunk_indices_p, chunk_offsets_p = None, None
    # Need flags to indicate if there are initial states
    # currently we really only support the FlashAttention backend
    has_initial_states_p = None
    prep_initial_states = False

    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
    if num_prefills > 0:
        attn_metadata_instances = get_platform_metadata_classes()
        if (isinstance(attn_metadata, attn_metadata_instances)
                and attn_metadata.context_lens_tensor is not None):
            # precompute flag to avoid device syncs later in mamba2 layer
            # forwards
            # prep is only needed for mamba2 ssd prefill processing
            has_initial_states_p = (
                attn_metadata.context_lens_tensor[:num_prefills] > 0)
            prep_initial_states = torch.any(has_initial_states_p).item()
        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
        seq_idx_p = torch.repeat_interleave(torch.arange(
            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
                                            query_start_loc_p.diff(),
                                            output_size=num_prefill_tokens)
        seq_idx_p.unsqueeze_(0)

        # We compute metadata for chunked prefill once at the top level model
        # forward and reuse them in mamba layers. If not needed, they will be
        # ignored inside mamba kernels.
        if prep_initial_states:
            chunk_indices_p, chunk_offsets_p = \
                _query_start_loc_to_chunk_indices_offsets(
                    query_start_loc_p, chunk_size, num_prefill_tokens)

    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                          prep_initial_states=prep_initial_states,
                          chunk_size=chunk_size,
                          seq_idx_p=seq_idx_p,
                          chunk_indices_p=chunk_indices_p,
                          chunk_offsets_p=chunk_offsets_p)


def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
                    mamba2_metadata: Union[Mamba2Metadata,
                                           Mamba2AttentionMetadata,
                                           GDNAttentionMetadata]):
    """
    this is triggered upon handling a new input at the first layer
    """
    dim, cu_seqlen = x.shape
    mamba2_metadata.cu_seqlen = cu_seqlen
    seqlens = np.diff(query_start_loc.to('cpu'))
    nums_dict = {}  # type: ignore
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]['nums'] = nums
        nums_dict[BLOCK_M]['tot'] = nums.sum().item()
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]['mlist'] = mlist
        mlist_len = len(nums_dict[BLOCK_M]['mlist'])
        nums_dict[BLOCK_M]['mlist_len'] = mlist_len
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
        offsetlist = []  # type: ignore
        for idx, num in enumerate(nums):
            offsetlist.extend(range(num))
        offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
        nums_dict[BLOCK_M]['offsetlist'] = offsetlist

        if mamba2_metadata.batch_ptr is None:
            # Update default value after class definition
            #mamba2_metadata.MAX_NUM_PROGRAMS *= 2
            mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                                   PAD_SLOT_ID,
                                                   dtype=torch.int32,
                                                   device='cuda')
            mamba2_metadata.token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS, ),
                PAD_SLOT_ID,
                dtype=torch.int32,
                device='cuda')
        else:
            if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
                    PAD_SLOT_ID)
                mamba2_metadata.token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)

        mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
        mamba2_metadata.token_chunk_offset_ptr[  # type: ignore
            0:mlist_len].copy_(offsetlist)
        nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
        nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
            mamba2_metadata.token_chunk_offset_ptr)  # type: ignore
    mamba2_metadata.nums_dict = nums_dict
    return mamba2_metadata
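
Note (illustrative sketch, not part of this diff): the bookkeeping that update_metadata builds above covers each request with ceil(seqlen / BLOCK_M) Triton programs; batch_ptr records which request a program handles and token_chunk_offset_ptr records which BLOCK_M-sized chunk within that request. A minimal NumPy version of the same mapping, using hypothetical sequence lengths:

# Illustrative only: the BLOCK_M -> (batch_ptr, token_chunk_offset_ptr) mapping.
import numpy as np

BLOCK_M = 8                           # tokens handled by one Triton program
seqlens = np.array([3, 17, 8])        # per-request token counts (hypothetical)

nums = -(-seqlens // BLOCK_M)         # ceil-division: programs per request -> [1, 3, 1]
batch_ptr = np.repeat(np.arange(len(nums)), nums)                      # [0, 1, 1, 1, 2]
token_chunk_offset_ptr = np.concatenate([np.arange(n) for n in nums])  # [0, 0, 1, 2, 0]

print(nums.tolist(), batch_ptr.tolist(), token_chunk_offset_ptr.tolist())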
@@ -10,8 +10,6 @@ import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
            has_weight=rms_norm_has_weight,
        ) if use_rms_norm else None

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self
            # The outer list is for v0 PP virtual engine. Though this code path
            # only runs for v1, we have to do this to unify with the interface
            # of Attention + v0 PP.
            # The inner tuple is (conv_state, ssm_state)
            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The inner tuple is (conv_state, ssm_state)
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

        self.model_config = model_config
        self.cache_config = cache_config
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        return discrete_time_step, B, C

    def forward(self,
                hidden_states: torch.Tensor,
                output: torch.Tensor,
                mamba_cache_params: Optional[MambaCacheParams] = None):
        if not envs.VLLM_USE_V1:
            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
        else:
            torch.ops.vllm.mamba_mixer(
                hidden_states,
                output,
                self.prefix,
            )
    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
        torch.ops.vllm.mamba_mixer(
            hidden_states,
            output,
            self.prefix,
        )

    def forward_native(self,
                       hidden_states: torch.Tensor,
                       output: torch.Tensor,
                       mamba_cache_params: Optional[MambaCacheParams] = None):
    def forward_native(self, hidden_states: torch.Tensor,
                       output: torch.Tensor):
        pass

    def forward_cuda(self,
                     hidden_states: torch.Tensor,
                     output: torch.Tensor,
                     mamba_cache_params: Optional[MambaCacheParams] = None):
    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
        """
        Run the Mamba-1 SSM pipeline.
@@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                mamba1_metadata = attn_metadata
                assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
                query_start_loc = mamba1_metadata.query_start_loc
                state_indices_tensor = mamba1_metadata.state_indices_tensor
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                conv_state = self_kv_cache[0].transpose(-1, -2)
                ssm_state = self_kv_cache[1]
                has_initial_states = mamba1_metadata.has_initial_states
                num_padded_decodes = mamba1_metadata.num_padded_decodes
        else:
            assert isinstance(attn_metadata, AttentionMetadata)
            assert mamba_cache_params is not None
            conv_state = mamba_cache_params.conv_state
            ssm_state = mamba_cache_params.ssm_state
            state_indices_tensor = mamba_cache_params.state_indices_tensor
            query_start_loc = attn_metadata.query_start_loc
            context_lens_tensor = attn_metadata.context_lens_tensor
            has_initial_states = None
            if context_lens_tensor is not None:
                has_initial_states = context_lens_tensor > 0
            num_padded_decodes = attn_metadata.num_decode_tokens
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            mamba1_metadata = attn_metadata
            assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
            query_start_loc = mamba1_metadata.query_start_loc
            state_indices_tensor = mamba1_metadata.state_indices_tensor
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
            has_initial_states = mamba1_metadata.has_initial_states
            num_padded_decodes = mamba1_metadata.num_padded_decodes

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if envs.VLLM_USE_V1 and attn_metadata is None:
        if attn_metadata is None:
            # V1 profile run
            hidden_states_BC = hidden_states_BC.contiguous()
            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
                out=scan_outputs_d)
            scan_outputs_d = scan_outputs_d.transpose(0, 1)

            if envs.VLLM_USE_V1:
                ssm_outputs.insert(0, scan_outputs_d)
            else:
                ssm_outputs.append(scan_outputs_d)
            ssm_outputs.insert(0, scan_outputs_d)

        scan_outputs_combined = ssm_outputs[0] if len(
            ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
@@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
    num_decodes: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit:

    num_actual_tokens = num_prefill_tokens + num_padded_decodes

    if envs.VLLM_USE_V1:
        # In v1, decode tokens come first, then prefill tokens.
        hidden_states_BC_d, hidden_states_BC_p = torch.split(
            hidden_states_BC[..., :num_actual_tokens],
            [num_padded_decodes, num_prefill_tokens],
            dim=-1)
        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
                                     [num_padded_decodes, num_prefill_tokens],
                                     dim=-1)
    # In v1, decode tokens come first, then prefill tokens.
    hidden_states_BC_d, hidden_states_BC_p = torch.split(
        hidden_states_BC[..., :num_actual_tokens],
        [num_padded_decodes, num_prefill_tokens],
        dim=-1)
    gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
                                 [num_padded_decodes, num_prefill_tokens],
                                 dim=-1)

        # num_padded_decodes accounts for CUDA graph padding when applicable
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor[:num_padded_decodes + num_prefills],
            [num_padded_decodes, num_prefills],
            dim=0)
        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
                             num_padded_decodes if num_prefills > 0 else None)
        has_initial_states_p = has_initial_states[-num_prefills:] if (
            has_initial_states is not None and num_prefills > 0) else None
    else:
        # In v0, prefill tokens come first, then decode tokens.
        hidden_states_BC_p, hidden_states_BC_d = torch.split(
            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
        gate_p, gate_d = torch.split(gate,
                                     [num_prefill_tokens, num_decode_tokens],
                                     dim=-1)
        state_indices_tensor_p, state_indices_tensor_d = torch.split(
            state_indices_tensor, [num_prefills, num_decodes], dim=0)
        query_start_loc_p = (query_start_loc[:num_prefills +
                                             1] if num_prefills > 0 else None)
        has_initial_states_p = has_initial_states[:num_prefills] if (
            has_initial_states is not None and num_prefills > 0) else None
    # num_padded_decodes accounts for CUDA graph padding when applicable
    state_indices_tensor_d, state_indices_tensor_p = torch.split(
        state_indices_tensor[:num_padded_decodes + num_prefills],
        [num_padded_decodes, num_prefills],
        dim=0)
    query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
                         num_padded_decodes if num_prefills > 0 else None)
    has_initial_states_p = has_initial_states[-num_prefills:] if (
        has_initial_states is not None and num_prefills > 0) else None

    return PrefillDecodeSplit(
        hidden_states_BC_p=hidden_states_BC_p,
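
Note (illustrative sketch, not part of this diff): the retained V1 path assumes decode tokens precede prefill tokens in the flattened batch, so one torch.split separates the two groups, and the prefill slice of query_start_loc is re-based by subtracting the padded decode count. A small standalone example with hypothetical sizes:

# Illustrative only: the V1 "decode first, then prefill" token layout.
import torch

num_decodes, num_prefills, num_prefill_tokens = 3, 2, 5
hidden = torch.arange(num_decodes + num_prefill_tokens)

# Decode tokens sit at the front of the token dimension, prefills at the back.
hidden_d, hidden_p = torch.split(hidden, [num_decodes, num_prefill_tokens], dim=0)

# query_start_loc lists the decode requests first (one token each), then the
# prefills; the prefill slice is shifted back to zero-based offsets, mirroring
# `query_start_loc[-num_prefills - 1:] - num_padded_decodes` above.
query_start_loc = torch.tensor([0, 1, 2, 3, 5, 8])
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes

print(hidden_d.tolist(), hidden_p.tolist(), query_start_loc_p.tolist())
# [0, 1, 2] [3, 4, 5, 6, 7] [0, 2, 5]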
@@ -495,9 +448,7 @@ def mamba_mixer(
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                      output=output,
                      mamba_cache_params=None)
    self.forward_cuda(hidden_states=hidden_states, output=output)


def mamba_mixer_fake(
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch
from torch import nn

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                              update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
    LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
            self.use_rms_norm,
            eps=rms_norm_eps)

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self
            # The outer list is for v0 PP virtual engine. Though this code path
            # only runs for v1, we have to do this to unify with the interface
            # of Attention + v0 PP.
            # The inner tuple is (conv_state, ssm_state)
            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The tuple is (conv_state, ssm_state)
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

        self.model_config = model_config
        self.cache_config = cache_config
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        pass
@@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        if not envs.VLLM_USE_V1:
            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
                             mamba2_metadata, mup_vector)
        else:
            torch.ops.vllm.mamba_mixer2(
                hidden_states,
                output,
                self.prefix,
                mup_vector,
            )
        torch.ops.vllm.mamba_mixer2(
            hidden_states,
            output,
            self.prefix,
            mup_vector,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        forward_context = get_forward_context()
        # mamba2_metadata contains metadata necessary for the mamba2 triton
        # attn_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
        # modes; they are computed at top-level model forward since they
        # stay the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                mamba2_metadata = attn_metadata
                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                # conv_state = (..., dim, width-1) yet contiguous along 'dim'
                conv_state = self_kv_cache[0].transpose(-1, -2)
                ssm_state = self_kv_cache[1]
                state_indices_tensor = attn_metadata.state_indices_tensor
        else:
            conv_state = mamba_cache_params.conv_state
            ssm_state = mamba_cache_params.ssm_state
            state_indices_tensor = mamba_cache_params.state_indices_tensor

        # Common members between V1 metadata and V0 metadata
        if mamba2_metadata is not None:
            has_initial_states_p = mamba2_metadata.has_initial_states_p
            prep_initial_states = mamba2_metadata.prep_initial_states
            chunk_size = mamba2_metadata.chunk_size
            seq_idx_p = mamba2_metadata.seq_idx_p
            chunk_indices_p = mamba2_metadata.chunk_indices_p
            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, Mamba2AttentionMetadata)
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # conv_state = (..., dim, width-1) yet contiguous along 'dim'
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states_p
            prep_initial_states = attn_metadata.prep_initial_states
            chunk_size = attn_metadata.chunk_size
            seq_idx_p = attn_metadata.seq_idx_p
            chunk_indices_p = attn_metadata.chunk_indices_p
            chunk_offsets_p = attn_metadata.chunk_offsets_p

        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
            dim=-1,
        )

        if envs.VLLM_USE_V1 and attn_metadata is None:
            # V1 profile run
        if attn_metadata is None:
            # profile run
            hidden_states_B_C = (hidden_states_B_C.transpose(
                0, 1).clone().transpose(0, 1)).contiguous()
            hidden_states, _B, _C = split_hidden_states_B_C_fn(
@@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
        has_decode = num_decodes > 0
        num_actual_tokens = num_prefill_tokens + num_decodes

        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        if envs.VLLM_USE_V1:
            hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                hidden_states_B_C[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            dt_d, dt_p = torch.split(
                dt[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_d, state_indices_tensor_p = torch.split(
                state_indices_tensor[:num_actual_tokens],
                [num_decodes, num_prefills],
                dim=0,
            )
            query_start_loc_p = (
                attn_metadata.query_start_loc[-num_prefills - 1:] -
                num_decodes if has_prefill else None)
        else:
            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
                hidden_states_B_C,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            dt_p, dt_d = torch.split(
                dt,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_p, state_indices_tensor_d = torch.split(
                state_indices_tensor,
                [num_prefills, num_decodes],
                dim=0,
            )
            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
                                                               1]
                                 if has_prefill else None)
        hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
            hidden_states_B_C[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        dt_d, dt_p = torch.split(
            dt[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        # Split along batch dimension
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor[:num_actual_tokens],
            [num_decodes, num_prefills],
            dim=0,
        )
        query_start_loc_p = (
            attn_metadata.query_start_loc[-num_prefills - 1:] -
            num_decodes if has_prefill else None)

        # Preallocate output tensor to avoid memcpy cost for merging prefill
        # and decode outputs
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        if envs.VLLM_USE_V1:
            preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
                preallocated_ssm_out,
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
        else:
            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
                preallocated_ssm_out,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
        preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
            preallocated_ssm_out,
            [num_decodes, num_prefill_tokens],
            dim=0,
        )

        # Process prefill requests
        if has_prefill:
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
            # pointed to by "state_indices_tensor"
            x = hidden_states_B_C_p.transpose(
                0, 1)  # this is the form that causal-conv see
            if mamba2_metadata.cu_seqlen is None:
                mamba2_metadata = update_metadata(x, query_start_loc_p,
                                                  mamba2_metadata)
            hidden_states_B_C_p = causal_conv1d_fn(
                x,
                conv_weights,
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                metadata=mamba2_metadata,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p).transpose(
                    0, 1)[:num_prefill_tokens]
@@ -806,8 +748,6 @@ def mamba_mixer2(
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                      output=output,
                      mamba_cache_params=None,
                      mamba2_metadata=None,
                      mup_vector=mup_vector)
@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        conv_state_shape = (divide(intermediate_size,
                                   tp_world_size), conv_kernel - 1)
@@ -108,11 +107,7 @@ class MambaStateShapeCalculator:
        temporal_state_shape = (divide(intermediate_size,
                                       tp_world_size), state_size)

        # In V0, the conv_state shape was swapped during allocation in
        # MambaCacheManager, but in V1 it needs to be determined here at the
        # calculation level
        if use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        return conv_state_shape, temporal_state_shape
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
        head_dim: int,
        state_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
        # contiguous along 'dim' axis
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
        if not use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int]]:
        conv_dim = divide(intermediate_size, tp_world_size)
        conv_state_shape = (conv_kernel - 1, conv_dim)
        if not use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        return (conv_state_shape, )

    @classmethod
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
        head_v_dim: int,
        conv_kernel_size: int,
        num_spec: int = 0,
        use_v1: bool = True,
    ):
        conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
        conv_state_shape = (
@@ -191,11 +179,7 @@ class MambaStateShapeCalculator:
            conv_kernel_size - 1 + num_spec,
        )

        # In V0, the conv_state shape was swapped during allocation in
        # MambaCacheManager, but in V1 it needs to be determined here at the
        # calculation level
        if use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        temporal_state_shape = (divide(num_v_heads,
                                       tp_world_size), head_k_dim, head_v_dim)
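
Note (illustrative sketch, not part of this diff): with the V0 branch gone, the conv state shape is always reported in its transposed V1 layout, (conv_kernel - 1, dim). A small standalone calculation with hypothetical dimensions and a local divide helper, mirroring the mamba1-style branch above:

# Illustrative only: V1-style mamba state shapes for hypothetical sizes.
def divide(a: int, b: int) -> int:
    assert a % b == 0
    return a // b

tp_world_size = 2
intermediate_size, state_size, conv_kernel = 4096, 16, 4

conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]   # always swapped in V1
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)

print(conv_state_shape, temporal_state_shape)  # (3, 2048) (2048, 16)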
@@ -420,9 +420,7 @@ def causal_conv1d_fn(
    x = x.to(conv_states.dtype)
    out = torch.empty_like(x)
    if metadata is not None:
        cu_seqlen = metadata.cu_seqlen
        nums_dict = metadata.nums_dict
        #x = metadata.x
        args = nums_dict
        batch_ptr = metadata.batch_ptr
        token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
@@ -926,7 +924,6 @@ def causal_conv1d_update(
    query_start_loc: Optional[torch.Tensor] = None,
    max_query_len: int = -1,
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
            prefix=f"{prefix}.out_proj",
        )

        assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The outer list is for v0 PP virtual engine. Though this code path
        # only runs for v1, we have to do this to unify with the interface
        # of Attention + v0 PP.
        self.kv_cache = [(torch.tensor([]), )]
        self.kv_cache = (torch.tensor([]), )

        self.model_config = model_config
        self.cache_config = cache_config
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        return
@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        torch.ops.vllm.short_conv(
            hidden_states,
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        forward_context = get_forward_context()
        # ShortConvAttentionMetadata contains metadata necessary for the
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            conv_metadata = attn_metadata
            assert isinstance(attn_metadata, ShortConvAttentionMetadata)
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            conv_state = self_kv_cache[0].transpose(-1, -2)
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
        if has_prefill:
            Bx_p = (B_p * x_p).transpose(0, 1)
            if conv_metadata.cu_seqlen is None:
                conv_metadata = update_metadata(Bx_p, query_start_loc_p,
                                                conv_metadata)
            Bx = causal_conv1d_fn(Bx_p,
                                  conv_weights,
                                  self.conv.bias,
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
                                  conv_states=conv_state,
                                  has_initial_state=has_initial_states_p,
                                  cache_indices=state_indices_tensor_p,
                                  metadata=conv_metadata,
                                  metadata=attn_metadata,
                                  query_start_loc=query_start_loc_p).transpose(
                                      0, 1)[:num_prefill_tokens]
@@ -248,9 +235,7 @@ def short_conv(
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                      output=output,
                      conv_metadata=None)
    self.forward_cuda(hidden_states=hidden_states, output=output)


def short_conv_fake(