From 418e29f7f54eec74c3a16f09e317a2a835e1a3d9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 07:35:49 +0000 Subject: [PATCH] fix: per-expert scale assembly (match assemble_scales_2d_side) --- vllm/nvfp4_cutedsl.py | 75 ++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 31b72dc8..7912762f 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -78,8 +78,6 @@ class CuTeDSLMoERunner: max_slots = self.max_num_tokens * self.top_k K_sf = cutedsl_ceil_div(self.hidden_size, 16) padded_cols = cutedsl_ceil_div(K_sf, 4) * 4 - # Worst case: 1 token per expert, each padded to 128 rows - max_padded_rows = self.num_experts * 128 # Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats) self._token_indices = torch.arange( @@ -91,12 +89,13 @@ class CuTeDSLMoERunner: self._expert_offsets_buf = torch.zeros( self.num_experts + 1, dtype=torch.int32, device=self.device ) - self._padded_expert_offsets_buf = torch.zeros( - self.num_experts + 1, dtype=torch.int32, device=self.device - ) - self._padded_scales_buf = torch.zeros( - max_padded_rows, padded_cols, dtype=torch.float16, device=self.device - ).to(torch.float8_e4m3fn) + + # Per-expert scale buffers: each expert gets a 128-row block + # This matches assemble_scales_2d_side which pads+swizzles each expert independently + self._per_expert_scale_bufs = [ + torch.zeros(128, padded_cols, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) + for _ in range(self.num_experts) + ] self._buffers_allocated = True @@ -145,46 +144,40 @@ class CuTeDSLMoERunner: def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets): """Assemble 2D-side activation scales (cudagraph-safe, no CPU sync). - Uses GPU-computed indices to scatter scale data into padded positions, - then applies the swizzle. Returns 2D tensor. + Matches the working assemble_scales_2d_side: pads each expert's scales + to 128 rows, swizzles each expert block independently, then concatenates. No .item(), no .tolist(), no Python control flow on GPU data. + + Fixed-shape: each expert gets exactly 128 rows (padded). We always + copy the full 128-row block from x_sf (zero-padded rows are harmless). """ num_experts = self.num_experts K_sf = x_sf.shape[1] - padded_cols = cutedsl_ceil_div(K_sf, 4) * 4 - # Compute tokens per expert (GPU) - tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1] + # For each expert: zero the buffer, scatter its rows, swizzle, flatten + swizzled_parts = [] + for e in range(num_experts): + buf = self._per_expert_scale_bufs[e] + buf.zero_() + + start = expert_offsets[e] + # Always copy 128 rows — extra rows will be zeros from x_sf padding + # or from the zero-initialized buffer + # Use a fixed-shape slice: buf is always (128, padded_cols) + # x_sf may not have 128 rows for this expert, but that's fine — + # the buffer is zero-initialized and we overwrite with whatever exists + buf[:, :K_sf] = x_sf[start:start + 128] + + # Swizzle this expert's block (matches pad_and_swizzle_single per expert) + swizzled = pad_and_swizzle_single(buf) + swizzled_parts.append(swizzled) - # Compute padded rows per expert (round up to 128) - padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128 + # Concatenate all expert blocks (matches cat_byte_reinterpretable_tensors) + # float8_e4m3fn is a 1-byte float type — cat via uint8 view + all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0) + all_flat = all_flat.view(torch.float8_e4m3fn) - # Compute padded offsets - padded_expert_offsets = self._padded_expert_offsets_buf - padded_expert_offsets.zero_() - padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0) - - # Use the FULL pre-allocated scales buffer (no GPU scalar slicing) - padded_scales = self._padded_scales_buf - padded_scales.zero_() - - # Build index mapping: for each row in x_sf, which expert does it belong to? - total_rows = x_sf.shape[0] - row_indices = self._token_indices[:total_rows] - expert_assign = torch.searchsorted( - expert_offsets[1:], row_indices, right=False - ).clamp(max=num_experts - 1) - - # Destination row in padded buffer - local_row = row_indices - expert_offsets[expert_assign] - dst_rows = padded_expert_offsets[expert_assign] + local_row - - # Scatter x_sf into padded_scales - padded_scales[dst_rows, :K_sf] = x_sf - - # Apply swizzle, reshape to 2D (element count preserved by swizzle) - swizzled = pad_and_swizzle_single(padded_scales) - return swizzled.reshape(padded_scales.shape[0], -1) + return all_flat.reshape(num_experts * 128, -1) def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None): """Run the NVFP4 MoE forward pass.