Fix: gs values are floats not tensors

This commit is contained in:
2026-05-17 21:19:47 +00:00
parent b05a38a9bd
commit 6eade5e7f8

View File

@@ -180,7 +180,7 @@ def test_stage2_l1_gemm(slot_hidden, expert_offsets, nvfp4_tensors, layer_idx, e
# Stack weights for GEMM
l1_mat_b = torch.stack(weights['l1_fp4'])
l1_scale_b = torch.stack(weights['l1_sf'])
l1_gsb = torch.stack(weights['l1_gs'])
l1_gsb = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=DEVICE)
# Make B-K major
l1_mat_b = make_b_k_major(l1_mat_b)
@@ -283,7 +283,7 @@ def test_stage4_l2_gemm(activated, expert_offsets, nvfp4_tensors, layer_idx, exp
l2_mat_b = torch.stack(weights['l2_fp4'])
l2_scale_b = torch.stack(weights['l2_sf'])
l2_gsb = torch.stack(weights['l2_gs'])
l2_gsb = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=DEVICE)
l2_mat_b = make_b_k_major(l2_mat_b)
l2_scale_b = assemble_scales_3d_side(l2_scale_b)