diff --git a/tests/unit/test_fused_swiglu_kernel.py b/tests/unit/test_fused_swiglu_kernel.py index 7527b668..e6f7901f 100644 --- a/tests/unit/test_fused_swiglu_kernel.py +++ b/tests/unit/test_fused_swiglu_kernel.py @@ -80,8 +80,8 @@ def test_fused_swiglu_compilation(): padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) # padded to 128 # Pad activation to 128 rows - x_padded = torch.zeros(128, K_packed, dtype=x_fp4.dtype, device=device) - x_padded[:tokens] = x_fp4 + x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) + x_padded.view(torch.uint8)[:tokens] = x_fp4.view(torch.uint8) # Assemble scales (simplified — just pad + swizzle) from dsv4.ops.layouts import pad_and_swizzle_single, ceil_div as cutedsl_ceil_div