[V0 Deprecation] Deprecate virtual engine (#37195)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -589,7 +589,7 @@ def get_attention_context(
|
||||
- attn_metadata: Attention metadata for this specific layer, or None if
|
||||
no metadata available
|
||||
- attn_layer: The attention layer instance (Attention or MLAAttention)
|
||||
- kv_cache: The KV cache tensor for current virtual engine
|
||||
- kv_cache: The KV cache tensor for current forward pass
|
||||
- slot_mapping: The slot mapping for this specific layer
|
||||
|
||||
Note: attn_metadata may be None, but attn_layer and kv_cache are always
|
||||
@@ -600,7 +600,7 @@ def get_attention_context(
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
kv_cache = attn_layer.kv_cache[0]
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
|
||||
@@ -480,7 +480,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self_kv_cache = self.kv_cache[0]
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
@@ -940,7 +940,7 @@ def unified_mla_kv_cache_update(
|
||||
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
|
||||
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
kv_cache = attn_layer.kv_cache[0]
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
|
||||
@@ -168,8 +168,7 @@ class StaticSinkAttention(Attention, CustomOp):
|
||||
"sink_key and sink_value have not been prepared"
|
||||
)
|
||||
if not self.sink_populated:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self_kv_cache = self.kv_cache[0]
|
||||
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
|
||||
|
||||
return super().forward(query, key, value, output_shape)
|
||||
|
||||
@@ -306,7 +306,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
constant_caches = self.kv_cache[forward_context.virtual_engine]
|
||||
constant_caches = self.kv_cache[0]
|
||||
|
||||
q_proj_states = q_proj_states[:num_actual_tokens]
|
||||
k_proj_states = k_proj_states[:num_actual_tokens]
|
||||
|
||||
@@ -413,7 +413,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
kv_cache = self.kv_cache[0][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache, state_indices_tensor, attn_metadata
|
||||
|
||||
@@ -267,7 +267,7 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
|
||||
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self_kv_cache = self.kv_cache[0]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
|
||||
@@ -575,7 +575,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self_kv_cache = self.kv_cache[0]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
|
||||
@@ -117,7 +117,7 @@ class ShortConv(MambaBase, CustomOp):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self_kv_cache = self.kv_cache[0]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
|
||||
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
|
||||
|
||||
Reference in New Issue
Block a user