fix: scatter+per-expert-swizzle scale assembly (cudagraph-safe)
This commit is contained in:
@@ -163,40 +163,49 @@ class CuTeDSLMoERunner:
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
|
||||
|
||||
Matches the working assemble_scales_2d_side: pads each expert's scales
|
||||
to 128 rows, swizzles each expert block independently, then concatenates.
|
||||
Each expert's data is placed at 128-row-aligned offsets so that a
|
||||
fixed 128-row slice always contains only that expert's data.
|
||||
Two-phase approach:
|
||||
1. Scatter x_sf rows into 128-aligned positions in padded_x_sf (GPU-only)
|
||||
2. Per-expert: copy 128 rows from padded_x_sf, swizzle, accumulate
|
||||
|
||||
All operations are fixed-shape. No .item(), no .tolist(), no variable
|
||||
slicing with GPU scalars.
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
K_sf = x_sf.shape[1]
|
||||
|
||||
# Zero the padded buffer, then scatter each expert's rows at 128-aligned offsets
|
||||
# ── Phase 1: Scatter x_sf into 128-aligned padded buffer ──
|
||||
# Use GPU-computed indices (same as original _assemble_scales_cudagraph_safe)
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
for e in range(num_experts):
|
||||
start = expert_offsets[e]
|
||||
end = expert_offsets[e + 1]
|
||||
num_rows = end - start
|
||||
# Place this expert's rows at offset e*128 in the padded buffer
|
||||
padded_x_sf[e * 128:e * 128 + num_rows, :K_sf] = x_sf[start:end]
|
||||
|
||||
# For each expert: zero the per-expert buf, copy from padded, swizzle
|
||||
tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]
|
||||
padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
padded_expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=x_sf.device)
|
||||
padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0)
|
||||
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = torch.arange(total_rows, device=x_sf.device)
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=False
|
||||
).clamp(max=num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# ── Phase 2: Per-expert swizzle ──
|
||||
# Each expert gets its own 128-row buf, copied from padded_x_sf at
|
||||
# the 128-aligned offset, then swizzled independently
|
||||
swizzled_parts = []
|
||||
for e in range(num_experts):
|
||||
buf = per_expert_bufs[e]
|
||||
buf.zero_()
|
||||
# Copy 128 rows starting at this expert's aligned offset
|
||||
buf[:, :K_sf] = padded_x_sf[e * 128:e * 128 + 128]
|
||||
|
||||
# Swizzle this expert's block (matches pad_and_swizzle_single per expert)
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
swizzled_parts.append(swizzled)
|
||||
|
||||
# Concatenate all expert blocks (matches cat_byte_reinterpretable_tensors)
|
||||
# Concatenate all expert blocks
|
||||
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
|
||||
all_flat = all_flat.view(torch.float8_e4m3fn)
|
||||
|
||||
return all_flat.reshape(num_experts * 128, -1)
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
|
||||
Reference in New Issue
Block a user