diff --git a/dsv4/model/layer.py b/dsv4/model/layer.py index 388e5eb7..77b7650d 100644 --- a/dsv4/model/layer.py +++ b/dsv4/model/layer.py @@ -41,7 +41,7 @@ class TransformerLayer: self.layer_idx = spec.layer_idx # Two mHC wrappers — one per sub-block. mHCLayer holds its own - # projection weights (W_pre, W_res, W_post) and static biases. + # projection weights (W_pre, W_post, W_comb) and static biases. self.mhc_attn = mHCLayer( hidden_dim=config.hidden_size, n_hc=config.n_hc, diff --git a/tests/layer_compare.py b/tests/layer_compare.py index a47bba9d..075ccfb2 100644 --- a/tests/layer_compare.py +++ b/tests/layer_compare.py @@ -89,14 +89,14 @@ def main(): n = n_hc attn_mhc.load_weights( W_pre=fn[0:n].to(device, dtype=torch.float32), - W_res=fn[n:n+n*n].to(device, dtype=torch.float32), - W_post=fn[n+n*n:].to(device, dtype=torch.float32), + W_post=fn[n:2*n].to(device, dtype=torch.float32), + W_comb=fn[2*n:].to(device, dtype=torch.float32), S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16), - S_res=base[n:n+n*n].reshape(n, n).to(device, dtype=torch.bfloat16), - S_post=base[n+n*n:].reshape(n, 1).to(device, dtype=torch.bfloat16), + S_post=base[n:2*n].reshape(n, 1).to(device, dtype=torch.bfloat16), + S_comb=base[2*n:].reshape(n, n).to(device, dtype=torch.bfloat16), alpha_pre=scale[0].item(), - alpha_res=scale[1].item(), - alpha_post=scale[2].item(), + alpha_post=scale[1].item(), + alpha_comb=scale[2].item(), ) # === OUR IMPLEMENTATION (single_shot_inference) ===