diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py index 83b79826..30c8a555 100644 --- a/dsv4/layers/grouped_linear.py +++ b/dsv4/layers/grouped_linear.py @@ -224,9 +224,10 @@ class Nvfp4GroupedLinear: dtype=torch.bfloat16, device=self.device ) # Pre-allocate FLAT output buffer for grouped GEMM (graph capture) - # The GEMM produces (tokens_sum, n_dim) where n_dim = n_local_groups * o_lora_rank + # The GEMM produces (tokens_sum, n_dim) where n_dim = o_lora_rank + # tokens_sum = n_local_groups * padded_rows_per_group (max = n_local_groups * max_num_tokens) self._output_buf_padded = torch.zeros( - self.max_num_tokens, self.n_local_groups * self.o_lora_rank, + self.max_num_tokens * self.n_local_groups, self.o_lora_rank, dtype=torch.bfloat16, device=self.device ) # Pre-allocate scale_a swizzle buffer for graph capture @@ -396,16 +397,15 @@ class Nvfp4GroupedLinear: ) # Extract real outputs and reshape - # GEMM output has the same layout as mat_a: groups-first with padding - # For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop. - # For T>1 prefill: Python loop is OK (not graph-captured). + # GEMM output layout: (tokens_sum, o_lora_rank) where tokens_sum = n_local_groups * padded_rows_per_group + # Groups are stacked vertically: group 0 at rows [0, padded_rows_per_group), group 1 at [padded_rows_per_group, 2*padded_rows_per_group), etc. z_gem = z_gem if z_gem is not None else self._output_buf_padded z = self._output_buf[:num_tokens] if num_tokens == 1: # Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group - z_flat = z_gem[gather_indices] # (n_groups, o_rank) — GPU gather - z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank) + z_flat = z_gem[gather_indices] # (n_local_groups, o_lora_rank) — GPU gather + z[:, :, :] = z_flat.unsqueeze(0) # (1, n_local_groups, o_lora_rank) else: for g in range(self.n_local_groups): offset = g * padded_rows_per_group