fix: pad activation as uint8 view for float4 dtype

This commit is contained in:
2026-06-02 08:18:26 +00:00
parent 024be1a60b
commit fa769b6214

View File

@@ -80,8 +80,8 @@ def test_fused_swiglu_compilation():
padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) # padded to 128
# Pad activation to 128 rows
x_padded = torch.zeros(128, K_packed, dtype=x_fp4.dtype, device=device)
x_padded[:tokens] = x_fp4
x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
x_padded.view(torch.uint8)[:tokens] = x_fp4.view(torch.uint8)
# Assemble scales (simplified — just pad + swizzle)
from dsv4.ops.layouts import pad_and_swizzle_single, ceil_div as cutedsl_ceil_div