[V0 Deprecation] Refactor kv cache from list to element (#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-23 23:10:11 -04:00
committed by GitHub
parent de99d91ece
commit c59a132f96
27 changed files with 70 additions and 85 deletions

View File

@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
self.attn.kv_cache = kv_cache
# Build attn metadata
attn_metadata = self.builder.build(
@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
}
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[0]
kv_cache_unfused = attn_layer.kv_cache
del dummy
torch._dynamo.mark_dynamic(qkv, 0)
@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
}
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[0]
kv_cache_fused = attn_layer.kv_cache
del dummy
assert fusion_pass.matched_count == 1