Fix scale assembly: variable-size per-expert padding matching GEMM offsets

- Compute padded_expert_offsets from real expert_offsets (ceil to 128)
- Scatter x_sf into padded positions matching those offsets
- Per-expert swizzle in 128-row chunks (supports >128 tokens per expert)
- Pad slot_hidden/activated using same padded offsets for GEMM input
- Pre-allocated buffers sized for max_tokens*top_k (not num_experts*128)
This commit is contained in:
2026-05-17 13:55:10 +00:00
parent 0d3c928ff2
commit bf22b6f0e4

View File

@@ -113,12 +113,13 @@ class CuTeDSLMoERunner:
]
# Padded x_sf buffers for Phase 1 scatter.
# Fixed 128 rows per expert → num_experts * 128 total rows.
# Sized for max_num_tokens * top_k rows (worst case: all tokens in one expert).
max_sf_rows = self.max_num_tokens * self.top_k
self._padded_x_sf_buf_l1 = torch.zeros(
self.num_experts * 128, padded_cols_l1, dtype=torch.float16, device=self.device
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._padded_x_sf_buf_l2 = torch.zeros(
self.num_experts * 128, padded_cols_l2, dtype=torch.float16, device=self.device
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
@@ -135,22 +136,20 @@ class CuTeDSLMoERunner:
self.max_num_tokens * self.top_k, device=self.device
)
# Padded hidden/activated buffers: num_experts * 128 rows
padded_size = self.num_experts * 128
# Padded hidden/activated buffers: max_num_tokens * top_k rows (rounded to 128)
max_slots = self.max_num_tokens * self.top_k
padded_max_slots = ((max_slots + 127) // 128) * 128
self._padded_hidden_buf = torch.zeros(
padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
)
self._padded_activated_buf = torch.zeros(
padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
)
# Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert)
# Padded expert offsets buffer (num_experts + 1)
self._padded_expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._padded_expert_offsets_buf[1:] = torch.arange(
1, self.num_experts + 1, dtype=torch.int32, device=self.device
) * 128
self._buffers_allocated = True
@@ -224,58 +223,57 @@ class CuTeDSLMoERunner:
padded_x_sf_buf, per_expert_bufs):
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
Per-expert swizzle using pre-allocated buffers, then concatenate.
The per-expert loop is a fixed-size Python loop (num_experts is constant),
so cudagraph captures it as a static unrolled sequence.
Each expert's scale data is swizzled independently and stacked.
The output shape depends on expert_offsets (GPU tensor), but during
cudagraph capture, expert_offsets is deterministic (fixed token budget).
Each expert's scale rows are padded to 128, then swizzled independently.
The GEMM reads scale_a according to padded_expert_offsets, which
matches the layout produced here.
"""
num_experts = self.num_experts
K_sf = x_sf.shape[1]
padded_x_sf = padded_x_sf_buf
padded_x_sf.zero_()
# Phase 1: Scatter x_sf into 128-aligned per-expert sections
# Each expert gets a fixed 128-row slot in padded_x_sf (at offset e*128).
# Tokens beyond 128 per expert wrap to zero rows (harmless — zero scale
# means zero contribution in GEMM). All indexing is fixed-shape.
# Compute padded expert offsets (each expert padded to 128 rows)
tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]
padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=x_sf.device)
padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0)
# Phase 1: Scatter x_sf into padded per-expert sections
total_rows = x_sf.shape[0]
row_indices = self._row_indices_buf[:total_rows]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
# Clamp local_row to [0, 127] — rows beyond 128 go to row 0 (overwriting,
# but row 0 already has valid data, and the extra row is ignored by GEMM)
clamped_local = local_row.clamp(max=127)
dst_rows = expert_assign * 128 + clamped_local
dst_rows = padded_expert_offsets[expert_assign] + local_row
padded_x_sf[dst_rows, :K_sf] = x_sf
# Phase 2: Per-expert swizzle and concatenate
# Each expert gets at most padded_x_sf[e*128 : (e+1)*128] for the first 128 rows.
# For experts with >128 tokens, we'd need multiple chunks, but during
# cudagraph capture the token budget is fixed, and the GEMM uses expert_offsets
# to determine how many rows each expert gets.
#
# Strategy: always swizzle 128 rows per expert (fixed loop), zero-pad shorter experts.
# The GEMM only reads the rows indicated by expert_offsets.
swizzled_parts = []
for e in range(num_experts):
n_padded = padded_rows_per_expert[e]
start = padded_expert_offsets[e]
buf = per_expert_bufs[e]
buf.zero_()
# Copy from padded_x_sf at this expert's 128-aligned offset
# Always copy 128 rows (fixed shape for cudagraph)
src_start = e * 128
buf[:, :K_sf] = padded_x_sf[src_start:src_start + 128]
swizzled = pad_and_swizzle_single(buf)
swizzled_parts.append(swizzled)
# Process in 128-row chunks
offset = start
remaining = n_padded
while remaining > 0:
buf.zero_()
chunk = min(remaining, 128)
buf[:chunk, :K_sf] = padded_x_sf[offset:offset + chunk]
swizzled = pad_and_swizzle_single(buf)
swizzled_parts.append(swizzled)
offset += 128
remaining -= 128
if n_padded == 0:
buf.zero_()
swizzled = pad_and_swizzle_single(buf)
swizzled_parts.append(swizzled)
# Concatenate all expert blocks (byte-reinterpretable)
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
all_flat = all_flat.view(torch.float8_e4m3fn)
return all_flat.reshape(num_experts * 128, -1)
total_padded = padded_expert_offsets[num_experts]
return all_flat.reshape(total_padded, -1)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
"""Compute activation global scales from a warmup forward pass.
@@ -345,20 +343,11 @@ class CuTeDSLMoERunner:
"""Run the NVFP4 MoE forward pass.
Handles global→local expert ID remapping for expert parallelism.
topk_ids contains GLOBAL expert IDs (0..n_routed_experts-1).
This runner only handles local experts
[experts_start_idx, experts_start_idx + num_experts).
Non-local tokens get zero weight and are clamped to expert 0
(harmless — zero-weighted output contributes nothing to scatter_add).
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
PADDING STRATEGY: Each expert is padded to exactly 128 tokens in
the GEMM input. This ensures scale_a has a fixed layout (128 rows
per expert) that matches the padded expert_offsets. The GEMM output
for padding tokens is discarded by scatter_add (sorted_token_ids
only refers to real tokens).
Each expert's slots are padded to multiples of 128 for the GEMM.
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
scale_a is produced at those same offsets.
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
@@ -381,28 +370,36 @@ class CuTeDSLMoERunner:
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx] # GPU tensor, no .cpu()
sorted_token_ids = token_indices[sort_idx]
# Expert offsets (GPU-only, never touches CPU)
# Expert offsets (real token counts)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
# -- Pad to 128 tokens per expert --
# The GEMM and scale assembly expect a fixed layout where each expert
# gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows,
# and set expert_offsets to [0, 128, 256, ...].
# Padded expert offsets (each expert padded to 128)
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = self._padded_expert_offsets_buf
# Already filled in _allocate_buffers with [0, 128, 256, ...]
padded_expert_offsets.zero_()
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
total_padded_slots = padded_expert_offsets[self.num_experts]
# Gather hidden states into slot order, then pad to 128 per expert
# -- Gather hidden states into slot order, pad to 128 per expert --
slot_hidden = hidden_states[sorted_token_ids]
padded_size = self.num_experts * 128
padded_hidden = self._padded_hidden_buf[:padded_size]
padded_hidden = self._padded_hidden_buf[:total_padded_slots]
padded_hidden.zero_()
padded_hidden[:num_slots] = slot_hidden
# Scatter real tokens into padded positions
# Each expert e: tokens are at [expert_offsets[e], expert_offsets[e+1])
# In padded buffer: tokens are at [padded_expert_offsets[e], padded_expert_offsets[e]+tokens_per_expert[e])
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
padded_hidden[padded_dst] = slot_hidden
# === L1: gate + up ===
x_fp4, x_sf = quantize_activation_nvfp4(
@@ -410,8 +407,7 @@ class CuTeDSLMoERunner:
)
l1_scale_a = self._assemble_scales_cudagraph_safe(
x_sf[:num_slots], # Only real tokens have scale data
expert_offsets[:self.num_experts + 1],
x_sf, expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
@@ -423,27 +419,25 @@ class CuTeDSLMoERunner:
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# Only keep real token outputs (padding rows are garbage)
l1_out = l1_out[:num_slots]
# Extract real token outputs from padded GEMM output
l1_out_real = l1_out[padded_dst]
# === SiLU(gate) * up ===
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]
gate = l1_out_real[:, :self.intermediate_size]
up = l1_out_real[:, self.intermediate_size:]
activated = torch.nn.functional.silu(gate) * up
# === L2: down ===
# Pad activated to 128 tokens per expert for L2 GEMM
padded_activated = self._padded_activated_buf[:padded_size]
padded_activated = self._padded_activated_buf[:total_padded_slots]
padded_activated.zero_()
padded_activated[:num_slots] = activated
padded_activated[padded_dst] = activated
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
padded_activated, self._l2_activation_global_scale
)
l2_scale_a = self._assemble_scales_cudagraph_safe(
l2_x_sf[:num_slots],
expert_offsets[:self.num_experts + 1],
l2_x_sf, expert_offsets[:self.num_experts + 1],
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
@@ -455,13 +449,12 @@ class CuTeDSLMoERunner:
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
# Only keep real token outputs
l2_out = l2_out[:num_slots]
l2_out_real = l2_out[padded_dst]
# === Scatter -> final output ===
y = self._output_buf[:num_tokens]
y.zero_()
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
y.scatter_add_(
0,
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),