Fix grouped_linear GEMM output buffer shape and extraction
- _output_buf_padded: (max_tokens * n_groups, o_lora_rank) — matches GEMM output - Extraction: groups are stacked vertically, not horizontally - Each group's output is (padded_rows, o_lora_rank) with o_lora_rank columns
This commit is contained in:
@@ -224,9 +224,10 @@ class Nvfp4GroupedLinear:
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
# Pre-allocate FLAT output buffer for grouped GEMM (graph capture)
|
||||
# The GEMM produces (tokens_sum, n_dim) where n_dim = n_local_groups * o_lora_rank
|
||||
# The GEMM produces (tokens_sum, n_dim) where n_dim = o_lora_rank
|
||||
# tokens_sum = n_groups * padded_rows_per_group (max = n_groups * max_num_tokens)
|
||||
self._output_buf_padded = torch.zeros(
|
||||
self.max_num_tokens, self.n_local_groups * self.o_lora_rank,
|
||||
self.max_num_tokens * self.n_local_groups, self.o_lora_rank,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
# Pre-allocate scale_a swizzle buffer for graph capture
|
||||
@@ -396,16 +397,15 @@ class Nvfp4GroupedLinear:
|
||||
)
|
||||
|
||||
# Extract real outputs and reshape
|
||||
# GEMM output has the same layout as mat_a: groups-first with padding
|
||||
# For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop.
|
||||
# For T>1 prefill: Python loop is OK (not graph-captured).
|
||||
# GEMM output layout: (tokens_sum, o_lora_rank) where tokens_sum = n_groups * padded_rows
|
||||
# Groups are stacked vertically: group 0 at rows [0, padded_rows), group 1 at [padded_rows, 2*padded_rows), etc.
|
||||
z_gem = z_gem if z_gem is not None else self._output_buf_padded
|
||||
z = self._output_buf[:num_tokens]
|
||||
if num_tokens == 1:
|
||||
# Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only
|
||||
gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group
|
||||
z_flat = z_gem[gather_indices] # (n_groups, o_rank) — GPU gather
|
||||
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank)
|
||||
z_flat = z_gem[gather_indices] # (n_groups, o_lora_rank) — GPU gather
|
||||
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_lora_rank)
|
||||
else:
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
|
||||
Reference in New Issue
Block a user