fix: padded x_sf buffer for fixed-shape scale assembly

This commit is contained in:
2026-05-17 07:37:04 +00:00
parent 418e29f7f5
commit 8642946274

View File

@@ -69,8 +69,8 @@ class CuTeDSLMoERunner:
self._token_indices = None
self._expert_id_range = None
self._expert_offsets_buf = None
self._padded_scales_buf = None
self._padded_expert_offsets_buf = None
self._per_expert_scale_bufs = None
self._padded_x_sf_buf = None
self._buffers_allocated = False
def _allocate_buffers(self):
@@ -97,6 +97,12 @@ class CuTeDSLMoERunner:
for _ in range(self.num_experts)
]
# Padded x_sf buffer: num_experts * 128 rows so that fixed-shape slices
# x_sf[start:start+128] always have 128 rows (extra rows are zeros)
self._padded_x_sf_buf = torch.zeros(
self.num_experts * 128, padded_cols, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._buffers_allocated = True
def _ensure_stacked(self):
@@ -149,11 +155,16 @@ class CuTeDSLMoERunner:
No .item(), no .tolist(), no Python control flow on GPU data.
Fixed-shape: each expert gets exactly 128 rows (padded). We always
copy the full 128-row block from x_sf (zero-padded rows are harmless).
copy the full 128-row block (zero-padded rows are harmless).
"""
num_experts = self.num_experts
K_sf = x_sf.shape[1]
# Pad x_sf to num_experts * 128 rows so fixed-shape slices always work
padded_x_sf = self._padded_x_sf_buf
padded_x_sf.zero_()
padded_x_sf[:x_sf.shape[0], :K_sf] = x_sf
# For each expert: zero the buffer, scatter its rows, swizzle, flatten
swizzled_parts = []
for e in range(num_experts):
@@ -161,12 +172,8 @@ class CuTeDSLMoERunner:
buf.zero_()
start = expert_offsets[e]
# Always copy 128 rows — extra rows will be zeros from x_sf padding
# or from the zero-initialized buffer
# Use a fixed-shape slice: buf is always (128, padded_cols)
# x_sf may not have 128 rows for this expert, but that's fine —
# the buffer is zero-initialized and we overwrite with whatever exists
buf[:, :K_sf] = x_sf[start:start + 128]
# Always copy 128 rows from padded buffer — extra rows are zeros
buf[:, :K_sf] = padded_x_sf[start:start + 128]
# Swizzle this expert's block (matches pad_and_swizzle_single per expert)
swizzled = pad_and_swizzle_single(buf)