fix: per-expert scale assembly (match assemble_scales_2d_side)

This commit is contained in:
2026-05-17 07:35:49 +00:00
parent 7b95e76723
commit 418e29f7f5

View File

@@ -78,8 +78,6 @@ class CuTeDSLMoERunner:
max_slots = self.max_num_tokens * self.top_k
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
# Worst case: 1 token per expert, each padded to 128 rows
max_padded_rows = self.num_experts * 128
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
self._token_indices = torch.arange(
@@ -91,12 +89,13 @@ class CuTeDSLMoERunner:
self._expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._padded_expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._padded_scales_buf = torch.zeros(
max_padded_rows, padded_cols, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
# Per-expert scale buffers: each expert gets a 128-row block
# This matches assemble_scales_2d_side which pads+swizzles each expert independently
self._per_expert_scale_bufs = [
torch.zeros(128, padded_cols, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
for _ in range(self.num_experts)
]
self._buffers_allocated = True
@@ -145,46 +144,40 @@ class CuTeDSLMoERunner:
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
    """Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).

    Mirrors assemble_scales_2d_side: each expert's scale rows are placed in
    a zero-padded 128-row block, every block is swizzled independently with
    pad_and_swizzle_single, and the blocks are concatenated in expert order.

    No .item(), no .tolist(), no Python control flow on GPU data: rows are
    gathered with device-side indices instead of slicing ``x_sf`` with a
    tensor-valued bound (which would implicitly call ``.item()``), and the
    rows past an expert's token count are zeroed with a device-side mask.

    Args:
        x_sf: 2D scale tensor; rows are grouped by expert as described by
            ``expert_offsets``, columns are the per-row scale factors.
            (assumes a 1-byte float8 dtype matching the per-expert buffers
            -- TODO confirm against caller)
        expert_offsets: 1D integer tensor of ``num_experts + 1`` row
            offsets into ``x_sf``; expert ``e`` owns rows
            ``expert_offsets[e]:expert_offsets[e + 1]``.

    Returns:
        float8_e4m3fn tensor with ``num_experts * 128`` rows holding the
        concatenated swizzled per-expert blocks.

    NOTE(review): each expert is limited to one 128-row block; if an expert
    can receive more than 128 rows, the extra rows are dropped here --
    confirm this bound holds for every caller.
    """
    rows_per_block = 128
    num_experts = self.num_experts
    K_sf = x_sf.shape[1]
    total_rows = x_sf.shape[0]

    # Tokens routed to each expert, computed on the device.
    tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]

    # Row positions inside one expert block; shared by every expert.
    block_rows = torch.arange(rows_per_block, device=x_sf.device)

    swizzled_parts = []
    for e in range(num_experts):
        buf = self._per_expert_scale_bufs[e]
        buf.zero_()

        # total_rows is a host-side shape, so this branch is cudagraph-safe.
        if total_rows > 0:
            # Fixed-shape gather: always read 128 rows, clamping the source
            # index so it never runs past the end of x_sf.
            src_rows = (expert_offsets[e] + block_rows).clamp(max=total_rows - 1)
            buf[:, :K_sf] = x_sf[src_rows]

            # Rows beyond this expert's token count were read from a
            # neighbouring expert (or the clamped last row); reset them to
            # zero bytes so the padding really is zero.
            pad_rows = (block_rows >= tokens_per_expert[e]).unsqueeze(1)
            buf.view(torch.uint8).masked_fill_(pad_rows, 0)

        # Swizzle this expert's block (matches pad_and_swizzle_single per expert)
        swizzled_parts.append(pad_and_swizzle_single(buf))

    # Concatenate all expert blocks (matches cat_byte_reinterpretable_tensors).
    # float8_e4m3fn is a 1-byte float type, so concatenate through a uint8 view.
    all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
    all_flat = all_flat.view(torch.float8_e4m3fn)

    return all_flat.reshape(num_experts * rows_per_block, -1)
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Run the NVFP4 MoE forward pass.