Fix grouped_linear GEMM output buffer shape and extraction

- _output_buf_padded: allocate as (max_num_tokens * n_local_groups, o_lora_rank) so its shape matches the GEMM output
- Extraction: groups are stacked vertically, not horizontally
- Each group's output block is (padded_rows, o_lora_rank): padded_rows rows by o_lora_rank columns
This commit is contained in:
2026-06-03 22:26:40 +00:00
parent 92225b07e7
commit f57de06eb5

View File

@@ -224,9 +224,10 @@ class Nvfp4GroupedLinear:
dtype=torch.bfloat16, device=self.device
)
# Pre-allocate FLAT output buffer for grouped GEMM (graph capture)
# The GEMM produces (tokens_sum, n_dim) where n_dim = n_local_groups * o_lora_rank
# The GEMM produces (tokens_sum, n_dim) where n_dim = o_lora_rank
# tokens_sum = n_groups * padded_rows_per_group (max = n_groups * max_num_tokens)
self._output_buf_padded = torch.zeros(
self.max_num_tokens, self.n_local_groups * self.o_lora_rank,
self.max_num_tokens * self.n_local_groups, self.o_lora_rank,
dtype=torch.bfloat16, device=self.device
)
# Pre-allocate scale_a swizzle buffer for graph capture
@@ -396,16 +397,15 @@ class Nvfp4GroupedLinear:
)
# Extract real outputs and reshape
# GEMM output has the same layout as mat_a: groups-first with padding
# For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop.
# For T>1 prefill: Python loop is OK (not graph-captured).
# GEMM output layout: (tokens_sum, o_lora_rank) where tokens_sum = n_groups * padded_rows
# Groups are stacked vertically: group 0 at rows [0, padded_rows), group 1 at [padded_rows, 2*padded_rows), etc.
z_gem = z_gem if z_gem is not None else self._output_buf_padded
z = self._output_buf[:num_tokens]
if num_tokens == 1:
# Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only
gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group
z_flat = z_gem[gather_indices] # (n_groups, o_rank) — GPU gather
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank)
z_flat = z_gem[gather_indices] # (n_groups, o_lora_rank) — GPU gather
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_lora_rank)
else:
for g in range(self.n_local_groups):
offset = g * padded_rows_per_group