diff --git a/tests/test_o_projection_b200.py b/tests/test_o_projection_b200.py index 1640db0d..dc51e85c 100644 --- a/tests/test_o_projection_b200.py +++ b/tests/test_o_projection_b200.py @@ -33,9 +33,9 @@ HEAD_DIM = 512 NOPE_DIM = 448 ROPE_DIM = 64 Q_LORA_RANK = 1536 -O_LORA_RANK = 1536 -O_GROUPS = 8 -HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 16 +O_LORA_RANK = 1024 +O_GROUPS = 16 # from config (not TP-sharded) +HEADS_PER_GROUP = NUM_HEADS // O_GROUPS # 8 NUM_TOKENS = 4 _cache = {} @@ -122,7 +122,9 @@ def new_path_o_projection(o, positions, cos_sin_cache, wo_a_weight_bf16): # Step 2: wo_a BMM num_tokens = o_inv.shape[0] - hidden_dim = HEADS_PER_GROUP * HEAD_DIM # 8192 + # wo_a weight: (O_GROUPS * O_LORA_RANK, HEADS_PER_GROUP * HEAD_DIM) + hidden_dim = wo_a_weight_bf16.shape[1] # 4096 = HEADS_PER_GROUP * HEAD_DIM + out_dim = wo_a_weight_bf16.shape[0] # 16384 = O_GROUPS * O_LORA_RANK o_grouped = o_inv.view(num_tokens, O_GROUPS, hidden_dim) wo_a_w = wo_a_weight_bf16.view(O_GROUPS, O_LORA_RANK, hidden_dim) z = torch.bmm(