fix: pad activation as uint8 view for float4 dtype
This commit is contained in:
@@ -80,8 +80,8 @@ def test_fused_swiglu_compilation():
|
||||
padded_offsets = torch.tensor([128], dtype=torch.int32, device=device) # padded to 128
|
||||
|
||||
# Pad activation to 128 rows
|
||||
x_padded = torch.zeros(128, K_packed, dtype=x_fp4.dtype, device=device)
|
||||
x_padded[:tokens] = x_fp4
|
||||
x_padded = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
|
||||
x_padded.view(torch.uint8)[:tokens] = x_fp4.view(torch.uint8)
|
||||
|
||||
# Assemble scales (simplified — just pad + swizzle)
|
||||
from dsv4.ops.layouts import pad_and_swizzle_single, ceil_div as cutedsl_ceil_div
|
||||
|
||||
Reference in New Issue
Block a user