[V0 Deprecation] Refactor kv cache from list to element (#37487)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -148,7 +148,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
raw_tensor = raw_tensor.view(kv_cache_shape)
|
||||
kv_cache = raw_tensor.permute(*inv_order)
|
||||
|
||||
self.attn.kv_cache = [kv_cache]
|
||||
self.attn.kv_cache = kv_cache
|
||||
|
||||
# Build attn metadata
|
||||
attn_metadata = self.builder.build(
|
||||
@@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
|
||||
}
|
||||
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
|
||||
attn_layer = forward_context.no_compile_layers[model.layer_name]
|
||||
kv_cache_unfused = attn_layer.kv_cache[0]
|
||||
kv_cache_unfused = attn_layer.kv_cache
|
||||
del dummy
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
@@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
|
||||
}
|
||||
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
|
||||
attn_layer = forward_context.no_compile_layers[model.layer_name]
|
||||
kv_cache_fused = attn_layer.kv_cache[0]
|
||||
kv_cache_fused = attn_layer.kv_cache
|
||||
del dummy
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
Reference in New Issue
Block a user