Fix dims: o_groups=16, o_lora_rank=1024 from config
This commit is contained in:
@@ -33,9 +33,9 @@ HEAD_DIM = 512
|
||||
NOPE_DIM = 448
|
||||
ROPE_DIM = 64
|
||||
Q_LORA_RANK = 1536
|
||||
O_LORA_RANK = 1536
|
||||
O_GROUPS = 8
|
||||
HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 16
|
||||
O_LORA_RANK = 1024
|
||||
O_GROUPS = 16 # from config (not TP-sharded)
|
||||
HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 8
|
||||
NUM_TOKENS = 4
|
||||
|
||||
_cache = {}
|
||||
@@ -122,7 +122,9 @@ def new_path_o_projection(o, positions, cos_sin_cache, wo_a_weight_bf16):
|
||||
|
||||
# Step 2: wo_a BMM
|
||||
num_tokens = o_inv.shape[0]
|
||||
hidden_dim = HEADS_PER_GROUP * HEAD_DIM # 8192
|
||||
# wo_a weight: (O_GROUPS * O_LORA_RANK, HEADS_PER_GROUP * HEAD_DIM)
|
||||
hidden_dim = wo_a_weight_bf16.shape[1] # 4096 = HEADS_PER_GROUP * HEAD_DIM
|
||||
out_dim = wo_a_weight_bf16.shape[0] # 16384 = O_GROUPS * O_LORA_RANK
|
||||
o_grouped = o_inv.view(num_tokens, O_GROUPS, hidden_dim)
|
||||
wo_a_w = wo_a_weight_bf16.view(O_GROUPS, O_LORA_RANK, hidden_dim)
|
||||
z = torch.bmm(
|
||||
|
||||
Reference in New Issue
Block a user