Fix remaining mHC API references: layer_compare.py, layer.py comment
This commit is contained in:
@@ -41,7 +41,7 @@ class TransformerLayer:
|
||||
self.layer_idx = spec.layer_idx
|
||||
|
||||
# Two mHC wrappers — one per sub-block. mHCLayer holds its own
|
||||
- # projection weights (W_pre, W_res, W_post) and static biases.
|
||||
+ # projection weights (W_pre, W_post, W_comb) and static biases.
|
||||
self.mhc_attn = mHCLayer(
|
||||
hidden_dim=config.hidden_size,
|
||||
n_hc=config.n_hc,
|
||||
|
||||
@@ -89,14 +89,14 @@ def main():
|
||||
n = n_hc
|
||||
attn_mhc.load_weights(
|
||||
W_pre=fn[0:n].to(device, dtype=torch.float32),
|
||||
- W_res=fn[n:n+n*n].to(device, dtype=torch.float32),
|
||||
- W_post=fn[n+n*n:].to(device, dtype=torch.float32),
|
||||
+ W_post=fn[n:2*n].to(device, dtype=torch.float32),
|
||||
+ W_comb=fn[2*n:].to(device, dtype=torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16),
|
||||
- S_res=base[n:n+n*n].reshape(n, n).to(device, dtype=torch.bfloat16),
|
||||
- S_post=base[n+n*n:].reshape(n, 1).to(device, dtype=torch.bfloat16),
|
||||
+ S_post=base[n:2*n].reshape(n, 1).to(device, dtype=torch.bfloat16),
|
||||
+ S_comb=base[2*n:].reshape(n, n).to(device, dtype=torch.bfloat16),
|
||||
alpha_pre=scale[0].item(),
|
||||
- alpha_res=scale[1].item(),
|
||||
- alpha_post=scale[2].item(),
|
||||
+ alpha_post=scale[1].item(),
|
||||
+ alpha_comb=scale[2].item(),
|
||||
)
|
||||
|
||||
# === OUR IMPLEMENTATION (single_shot_inference) ===
|
||||
|
||||
Reference in New Issue
Block a user