Fix dims: o_groups=16, o_lora_rank=1024 from config

This commit is contained in:
2026-05-19 06:37:25 +00:00
parent b4fee70151
commit 199efe0871

View File

@@ -33,9 +33,9 @@ HEAD_DIM = 512
# Per-head dimension split: 448 + 64 = 512.
# NOTE(review): presumably the non-rotary and rotary slices of HEAD_DIM --
# confirm against the model config.
NOPE_DIM = 448
ROPE_DIM = 64
Q_LORA_RANK = 1536
# NOTE(review): the next three assignments carry the pre-fix values. Each name
# is reassigned immediately below, so these values never take effect.
O_LORA_RANK = 1536
O_GROUPS = 8
HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 16
# Corrected values taken from the model config; these are the ones in effect.
# NUM_HEADS is defined outside this view; the "# 8" result below implies
# NUM_HEADS == 128 -- TODO confirm.
O_LORA_RANK = 1024
O_GROUPS = 16 # from config (not TP-sharded)
HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 8
NUM_TOKENS = 4
# Module-level dict that starts empty; how it is populated is outside this view.
_cache = {}
@@ -122,7 +122,9 @@ def new_path_o_projection(o, positions, cos_sin_cache, wo_a_weight_bf16):
# Step 2: wo_a BMM
num_tokens = o_inv.shape[0]
hidden_dim = HEADS_PER_GROUP * HEAD_DIM # 8192
# wo_a weight: (O_GROUPS * O_LORA_RANK, HEADS_PER_GROUP * HEAD_DIM)
hidden_dim = wo_a_weight_bf16.shape[1] # 4096 = HEADS_PER_GROUP * HEAD_DIM
out_dim = wo_a_weight_bf16.shape[0] # 16384 = O_GROUPS * O_LORA_RANK
o_grouped = o_inv.view(num_tokens, O_GROUPS, hidden_dim)
wo_a_w = wo_a_weight_bf16.view(O_GROUPS, O_LORA_RANK, hidden_dim)
z = torch.bmm(