fix: force contiguous on SF tensors before C++ call
This commit is contained in:
@@ -319,6 +319,15 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
"""
|
||||
l1_w, l1_w_sf = l1_weights
|
||||
l2_w, l2_w_sf = l2_weights
|
||||
|
||||
# Force contiguous on SF tensors — non-contiguous SF breaks TMA descriptors
|
||||
for name, t in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]:
|
||||
if not t.is_contiguous():
|
||||
print(f"[contig-fix] {name}: was NOT contiguous, forcing", flush=True)
|
||||
# (assign back to correct variable)
|
||||
l1_w_sf = l1_w_sf.contiguous()
|
||||
l2_w_sf = l2_w_sf.contiguous()
|
||||
|
||||
for name, t in [("l1_w", l1_w), ("l1_w_sf", l1_w_sf),
|
||||
("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]:
|
||||
print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True)
|
||||
@@ -330,7 +339,7 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
|
||||
|
||||
_C.fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
l1_weights, l2_weights,
|
||||
(l1_w, l1_w_sf), (l2_w, l2_w_sf),
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer.buffer,
|
||||
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
|
||||
|
||||
Reference in New Issue
Block a user