Fix: gs values are floats not tensors
This commit is contained in:
@@ -180,7 +180,7 @@ def test_stage2_l1_gemm(slot_hidden, expert_offsets, nvfp4_tensors, layer_idx, e
|
||||
# Stack weights for GEMM
|
||||
l1_mat_b = torch.stack(weights['l1_fp4'])
|
||||
l1_scale_b = torch.stack(weights['l1_sf'])
|
||||
l1_gsb = torch.stack(weights['l1_gs'])
|
||||
l1_gsb = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=DEVICE)
|
||||
|
||||
# Make B-K major
|
||||
l1_mat_b = make_b_k_major(l1_mat_b)
|
||||
@@ -283,7 +283,7 @@ def test_stage4_l2_gemm(activated, expert_offsets, nvfp4_tensors, layer_idx, exp
|
||||
|
||||
l2_mat_b = torch.stack(weights['l2_fp4'])
|
||||
l2_scale_b = torch.stack(weights['l2_sf'])
|
||||
l2_gsb = torch.stack(weights['l2_gs'])
|
||||
l2_gsb = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=DEVICE)
|
||||
|
||||
l2_mat_b = make_b_k_major(l2_mat_b)
|
||||
l2_scale_b = assemble_scales_3d_side(l2_scale_b)
|
||||
|
||||
Reference in New Issue
Block a user