Fix remaining mHC API references (`*_res` renamed to `*_comb`, ordering now pre/post/comb): layer_compare.py call site, layer.py comment

This commit is contained in:
2026-05-31 18:38:34 +00:00
parent 7b123d159f
commit 7d9e70c5d5
2 changed files with 7 additions and 7 deletions

View File

@@ -41,7 +41,7 @@ class TransformerLayer:
self.layer_idx = spec.layer_idx
# Two mHC wrappers — one per sub-block. mHCLayer holds its own
# projection weights (W_pre, W_res, W_post) and static biases.
# projection weights (W_pre, W_post, W_comb) and static biases.
self.mhc_attn = mHCLayer(
hidden_dim=config.hidden_size,
n_hc=config.n_hc,

View File

@@ -89,14 +89,14 @@ def main():
n = n_hc
attn_mhc.load_weights(
W_pre=fn[0:n].to(device, dtype=torch.float32),
W_res=fn[n:n+n*n].to(device, dtype=torch.float32),
W_post=fn[n+n*n:].to(device, dtype=torch.float32),
W_post=fn[n:2*n].to(device, dtype=torch.float32),
W_comb=fn[2*n:].to(device, dtype=torch.float32),
S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16),
S_res=base[n:n+n*n].reshape(n, n).to(device, dtype=torch.bfloat16),
S_post=base[n+n*n:].reshape(n, 1).to(device, dtype=torch.bfloat16),
S_post=base[n:2*n].reshape(n, 1).to(device, dtype=torch.bfloat16),
S_comb=base[2*n:].reshape(n, n).to(device, dtype=torch.bfloat16),
alpha_pre=scale[0].item(),
alpha_res=scale[1].item(),
alpha_post=scale[2].item(),
alpha_post=scale[1].item(),
alpha_comb=scale[2].item(),
)
# === OUR IMPLEMENTATION (single_shot_inference) ===