diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 570c6ac0..4815fe3e 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -180,7 +180,7 @@ def test_stage2_l1_gemm(slot_hidden, expert_offsets, nvfp4_tensors, layer_idx, e # Stack weights for GEMM l1_mat_b = torch.stack(weights['l1_fp4']) l1_scale_b = torch.stack(weights['l1_sf']) - l1_gsb = torch.stack(weights['l1_gs']) + l1_gsb = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=DEVICE) # Make B-K major l1_mat_b = make_b_k_major(l1_mat_b) @@ -283,7 +283,7 @@ def test_stage4_l2_gemm(activated, expert_offsets, nvfp4_tensors, layer_idx, exp l2_mat_b = torch.stack(weights['l2_fp4']) l2_scale_b = torch.stack(weights['l2_sf']) - l2_gsb = torch.stack(weights['l2_gs']) + l2_gsb = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=DEVICE) l2_mat_b = make_b_k_major(l2_mat_b) l2_scale_b = assemble_scales_3d_side(l2_scale_b)