diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 7912762f..88d5aec9 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -69,8 +69,8 @@ class CuTeDSLMoERunner: self._token_indices = None self._expert_id_range = None self._expert_offsets_buf = None - self._padded_scales_buf = None - self._padded_expert_offsets_buf = None + self._per_expert_scale_bufs = None + self._padded_x_sf_buf = None self._buffers_allocated = False def _allocate_buffers(self): @@ -97,6 +97,12 @@ class CuTeDSLMoERunner: for _ in range(self.num_experts) ] + # Padded x_sf buffer: num_experts * 128 rows so that fixed-shape slices + # x_sf[start:start+128] always have 128 rows (extra rows are zeros) + self._padded_x_sf_buf = torch.zeros( + self.num_experts * 128, padded_cols, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn) + self._buffers_allocated = True def _ensure_stacked(self): @@ -149,11 +155,16 @@ class CuTeDSLMoERunner: No .item(), no .tolist(), no Python control flow on GPU data. Fixed-shape: each expert gets exactly 128 rows (padded). We always - copy the full 128-row block from x_sf (zero-padded rows are harmless). + copy the full 128-row block (zero-padded rows are harmless). """ num_experts = self.num_experts K_sf = x_sf.shape[1] + # Pad x_sf to num_experts * 128 rows so fixed-shape slices always work + padded_x_sf = self._padded_x_sf_buf + padded_x_sf.zero_() + padded_x_sf[:x_sf.shape[0], :K_sf] = x_sf + # For each expert: zero the buffer, scatter its rows, swizzle, flatten swizzled_parts = [] for e in range(num_experts): @@ -161,12 +172,8 @@ class CuTeDSLMoERunner: buf.zero_() start = expert_offsets[e] - # Always copy 128 rows — extra rows will be zeros from x_sf padding - # or from the zero-initialized buffer - # Use a fixed-shape slice: buf is always (128, padded_cols) - # x_sf may not have 128 rows for this expert, but that's fine — - # the buffer is zero-initialized and we overwrite with whatever exists - buf[:, :K_sf] = x_sf[start:start + 128] + # Always copy 128 rows from padded buffer — extra rows are zeros + buf[:, :K_sf] = padded_x_sf[start:start + 128] # Swizzle this expert's block (matches pad_and_swizzle_single per expert) swizzled = pad_and_swizzle_single(buf)