diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index e406d2bf..32cfbc02 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -163,40 +163,49 @@ class CuTeDSLMoERunner: padded_x_sf_buf, per_expert_bufs): """Assemble 2D-side activation scales (cudagraph-safe, no CPU sync). - Matches the working assemble_scales_2d_side: pads each expert's scales - to 128 rows, swizzles each expert block independently, then concatenates. - Each expert's data is placed at 128-row-aligned offsets so that a - fixed 128-row slice always contains only that expert's data. + Two-phase approach: + 1. Scatter x_sf rows into 128-aligned positions in padded_x_sf (GPU-only) + 2. Per-expert: copy 128 rows from padded_x_sf, swizzle, accumulate + + All operations are fixed-shape. No .item(), no .tolist(), no variable + slicing with GPU scalars. """ num_experts = self.num_experts K_sf = x_sf.shape[1] - # Zero the padded buffer, then scatter each expert's rows at 128-aligned offsets + # ── Phase 1: Scatter x_sf into 128-aligned padded buffer ── + # Use GPU-computed indices (same as original _assemble_scales_cudagraph_safe) padded_x_sf = padded_x_sf_buf padded_x_sf.zero_() - for e in range(num_experts): - start = expert_offsets[e] - end = expert_offsets[e + 1] - num_rows = end - start - # Place this expert's rows at offset e*128 in the padded buffer - padded_x_sf[e * 128:e * 128 + num_rows, :K_sf] = x_sf[start:end] - # For each expert: zero the per-expert buf, copy from padded, swizzle + tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1] + padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=x_sf.device) + padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0) + + total_rows = x_sf.shape[0] + row_indices = torch.arange(total_rows, device=x_sf.device) + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=False + ).clamp(max=num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + dst_rows = padded_expert_offsets[expert_assign] + local_row + padded_x_sf[dst_rows, :K_sf] = x_sf + + # ── Phase 2: Per-expert swizzle ── + # Each expert gets its own 128-row buf, copied from padded_x_sf at + # the 128-aligned offset, then swizzled independently swizzled_parts = [] for e in range(num_experts): buf = per_expert_bufs[e] buf.zero_() - # Copy 128 rows starting at this expert's aligned offset buf[:, :K_sf] = padded_x_sf[e * 128:e * 128 + 128] - - # Swizzle this expert's block (matches pad_and_swizzle_single per expert) swizzled = pad_and_swizzle_single(buf) swizzled_parts.append(swizzled) - # Concatenate all expert blocks (matches cat_byte_reinterpretable_tensors) + # Concatenate all expert blocks all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0) all_flat = all_flat.view(torch.float8_e4m3fn) - return all_flat.reshape(num_experts * 128, -1) def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):