fix: padded x_sf buffer for fixed-shape scale assembly
This commit is contained in:
@@ -69,8 +69,8 @@ class CuTeDSLMoERunner:
|
||||
self._token_indices = None
|
||||
self._expert_id_range = None
|
||||
self._expert_offsets_buf = None
|
||||
self._padded_scales_buf = None
|
||||
self._padded_expert_offsets_buf = None
|
||||
self._per_expert_scale_bufs = None
|
||||
self._padded_x_sf_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
def _allocate_buffers(self):
|
||||
@@ -97,6 +97,12 @@ class CuTeDSLMoERunner:
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
# Padded x_sf buffer: num_experts * 128 rows so that fixed-shape slices
|
||||
# x_sf[start:start+128] always have 128 rows (extra rows are zeros)
|
||||
self._padded_x_sf_buf = torch.zeros(
|
||||
self.num_experts * 128, padded_cols, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_stacked(self):
|
||||
@@ -149,11 +155,16 @@ class CuTeDSLMoERunner:
|
||||
No .item(), no .tolist(), no Python control flow on GPU data.
|
||||
|
||||
Fixed-shape: each expert gets exactly 128 rows (padded). We always
|
||||
copy the full 128-row block from x_sf (zero-padded rows are harmless).
|
||||
copy the full 128-row block (zero-padded rows are harmless).
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
K_sf = x_sf.shape[1]
|
||||
|
||||
# Pad x_sf to num_experts * 128 rows so fixed-shape slices always work
|
||||
padded_x_sf = self._padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
padded_x_sf[:x_sf.shape[0], :K_sf] = x_sf
|
||||
|
||||
# For each expert: zero the buffer, scatter its rows, swizzle, flatten
|
||||
swizzled_parts = []
|
||||
for e in range(num_experts):
|
||||
@@ -161,12 +172,8 @@ class CuTeDSLMoERunner:
|
||||
buf.zero_()
|
||||
|
||||
start = expert_offsets[e]
|
||||
# Always copy 128 rows — extra rows will be zeros from x_sf padding
|
||||
# or from the zero-initialized buffer
|
||||
# Use a fixed-shape slice: buf is always (128, padded_cols)
|
||||
# x_sf may not have 128 rows for this expert, but that's fine —
|
||||
# the buffer is zero-initialized and we overwrite with whatever exists
|
||||
buf[:, :K_sf] = x_sf[start:start + 128]
|
||||
# Always copy 128 rows from padded buffer — extra rows are zeros
|
||||
buf[:, :K_sf] = padded_x_sf[start:start + 128]
|
||||
|
||||
# Swizzle this expert's block (matches pad_and_swizzle_single per expert)
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
|
||||
Reference in New Issue
Block a user