diff --git a/tests/test_csa_sparse_attn_b200.py b/tests/test_csa_sparse_attn_b200.py index 507d027e..5f6566c4 100644 --- a/tests/test_csa_sparse_attn_b200.py +++ b/tests/test_csa_sparse_attn_b200.py @@ -216,8 +216,8 @@ def test_csa_layer(layer_id, compress_ratio): wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2") # Compressor weights - comp_kv_w = G(f"{a}.comp.kv_proj.weight"); comp_kv_sf = G(f"{a}.comp.kv_proj.weight_scale"); comp_kv_gs = G(f"{a}.comp.kv_proj.weight_scale_2") - comp_gate_w = G(f"{a}.comp.gate_proj.weight"); comp_gate_sf = G(f"{a}.comp.gate_proj.weight_scale"); comp_gate_gs = G(f"{a}.comp.gate_proj.weight_scale_2") + comp_kv_w = G(f"{a}.compressor.kv_proj.weight"); comp_kv_sf = G(f"{a}.compressor.kv_proj.weight_scale"); comp_kv_gs = G(f"{a}.compressor.kv_proj.weight_scale_2") + comp_gate_w = G(f"{a}.compressor.gate_proj.weight"); comp_gate_sf = G(f"{a}.compressor.gate_proj.weight_scale"); comp_gate_gs = G(f"{a}.compressor.gate_proj.weight_scale_2") r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0]) r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])