Fix scale assembly: variable-size per-expert padding matching GEMM offsets
- Compute padded_expert_offsets from real expert_offsets (ceil to 128) - Scatter x_sf into padded positions matching those offsets - Per-expert swizzle in 128-row chunks (supports >128 tokens per expert) - Pad slot_hidden/activated using same padded offsets for GEMM input - Pre-allocated buffers sized for max_tokens*top_k (not num_experts*128)
This commit is contained in:
@@ -113,12 +113,13 @@ class CuTeDSLMoERunner:
|
||||
]
|
||||
|
||||
# Padded x_sf buffers for Phase 1 scatter.
|
||||
# Fixed 128 rows per expert → num_experts * 128 total rows.
|
||||
# Sized for max_num_tokens * top_k rows (worst case: all tokens in one expert).
|
||||
max_sf_rows = self.max_num_tokens * self.top_k
|
||||
self._padded_x_sf_buf_l1 = torch.zeros(
|
||||
self.num_experts * 128, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
self._padded_x_sf_buf_l2 = torch.zeros(
|
||||
self.num_experts * 128, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
||||
@@ -135,22 +136,20 @@ class CuTeDSLMoERunner:
|
||||
self.max_num_tokens * self.top_k, device=self.device
|
||||
)
|
||||
|
||||
# Padded hidden/activated buffers: num_experts * 128 rows
|
||||
padded_size = self.num_experts * 128
|
||||
# Padded hidden/activated buffers: max_num_tokens * top_k rows (rounded to 128)
|
||||
max_slots = self.max_num_tokens * self.top_k
|
||||
padded_max_slots = ((max_slots + 127) // 128) * 128
|
||||
self._padded_hidden_buf = torch.zeros(
|
||||
padded_size, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
self._padded_activated_buf = torch.zeros(
|
||||
padded_size, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
||||
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Padded expert offsets: [0, 128, 256, ...] (fixed, 128 tokens per expert)
|
||||
# Padded expert offsets buffer (num_experts + 1)
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._padded_expert_offsets_buf[1:] = torch.arange(
|
||||
1, self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
) * 128
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
@@ -224,58 +223,57 @@ class CuTeDSLMoERunner:
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
|
||||
|
||||
Per-expert swizzle using pre-allocated buffers, then concatenate.
|
||||
The per-expert loop is a fixed-size Python loop (num_experts is constant),
|
||||
so cudagraph captures it as a static unrolled sequence.
|
||||
|
||||
Each expert's scale data is swizzled independently and stacked.
|
||||
The output shape depends on expert_offsets (GPU tensor), but during
|
||||
cudagraph capture, expert_offsets is deterministic (fixed token budget).
|
||||
Each expert's scale rows are padded to 128, then swizzled independently.
|
||||
The GEMM reads scale_a according to padded_expert_offsets, which
|
||||
matches the layout produced here.
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
K_sf = x_sf.shape[1]
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
|
||||
# Phase 1: Scatter x_sf into 128-aligned per-expert sections
|
||||
# Each expert gets a fixed 128-row slot in padded_x_sf (at offset e*128).
|
||||
# Tokens beyond 128 per expert wrap to zero rows (harmless — zero scale
|
||||
# means zero contribution in GEMM). All indexing is fixed-shape.
|
||||
# Compute padded expert offsets (each expert padded to 128 rows)
|
||||
tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]
|
||||
padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
padded_expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32, device=x_sf.device)
|
||||
padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0)
|
||||
|
||||
# Phase 1: Scatter x_sf into padded per-expert sections
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = self._row_indices_buf[:total_rows]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
# Clamp local_row to [0, 127] — rows beyond 128 go to row 0 (overwriting,
|
||||
# but row 0 already has valid data, and the extra row is ignored by GEMM)
|
||||
clamped_local = local_row.clamp(max=127)
|
||||
dst_rows = expert_assign * 128 + clamped_local
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Phase 2: Per-expert swizzle and concatenate
|
||||
# Each expert gets at most padded_x_sf[e*128 : (e+1)*128] for the first 128 rows.
|
||||
# For experts with >128 tokens, we'd need multiple chunks, but during
|
||||
# cudagraph capture the token budget is fixed, and the GEMM uses expert_offsets
|
||||
# to determine how many rows each expert gets.
|
||||
#
|
||||
# Strategy: always swizzle 128 rows per expert (fixed loop), zero-pad shorter experts.
|
||||
# The GEMM only reads the rows indicated by expert_offsets.
|
||||
swizzled_parts = []
|
||||
for e in range(num_experts):
|
||||
n_padded = padded_rows_per_expert[e]
|
||||
start = padded_expert_offsets[e]
|
||||
buf = per_expert_bufs[e]
|
||||
buf.zero_()
|
||||
# Copy from padded_x_sf at this expert's 128-aligned offset
|
||||
# Always copy 128 rows (fixed shape for cudagraph)
|
||||
src_start = e * 128
|
||||
buf[:, :K_sf] = padded_x_sf[src_start:src_start + 128]
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
swizzled_parts.append(swizzled)
|
||||
# Process in 128-row chunks
|
||||
offset = start
|
||||
remaining = n_padded
|
||||
while remaining > 0:
|
||||
buf.zero_()
|
||||
chunk = min(remaining, 128)
|
||||
buf[:chunk, :K_sf] = padded_x_sf[offset:offset + chunk]
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
swizzled_parts.append(swizzled)
|
||||
offset += 128
|
||||
remaining -= 128
|
||||
if n_padded == 0:
|
||||
buf.zero_()
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
swizzled_parts.append(swizzled)
|
||||
|
||||
# Concatenate all expert blocks (byte-reinterpretable)
|
||||
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
|
||||
all_flat = all_flat.view(torch.float8_e4m3fn)
|
||||
return all_flat.reshape(num_experts * 128, -1)
|
||||
total_padded = padded_expert_offsets[num_experts]
|
||||
return all_flat.reshape(total_padded, -1)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
|
||||
"""Compute activation global scales from a warmup forward pass.
|
||||
@@ -345,20 +343,11 @@ class CuTeDSLMoERunner:
|
||||
"""Run the NVFP4 MoE forward pass.
|
||||
|
||||
Handles global→local expert ID remapping for expert parallelism.
|
||||
topk_ids contains GLOBAL expert IDs (0..n_routed_experts-1).
|
||||
This runner only handles local experts
|
||||
[experts_start_idx, experts_start_idx + num_experts).
|
||||
|
||||
Non-local tokens get zero weight and are clamped to expert 0
|
||||
(harmless — zero-weighted output contributes nothing to scatter_add).
|
||||
|
||||
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
||||
|
||||
PADDING STRATEGY: Each expert is padded to exactly 128 tokens in
|
||||
the GEMM input. This ensures scale_a has a fixed layout (128 rows
|
||||
per expert) that matches the padded expert_offsets. The GEMM output
|
||||
for padding tokens is discarded by scatter_add (sorted_token_ids
|
||||
only refers to real tokens).
|
||||
Each expert's slots are padded to multiples of 128 for the GEMM.
|
||||
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
|
||||
scale_a is produced at those same offsets.
|
||||
"""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
@@ -381,28 +370,36 @@ class CuTeDSLMoERunner:
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_weights = flat_weights[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx] # GPU tensor, no .cpu()
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
|
||||
# Expert offsets (GPU-only, never touches CPU)
|
||||
# Expert offsets (real token counts)
|
||||
expert_id_range = self._expert_id_range
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
# -- Pad to 128 tokens per expert --
|
||||
# The GEMM and scale assembly expect a fixed layout where each expert
|
||||
# gets exactly 128 rows. Pad slot_hidden to num_experts * 128 rows,
|
||||
# and set expert_offsets to [0, 128, 256, ...].
|
||||
# Padded expert offsets (each expert padded to 128)
|
||||
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
padded_expert_offsets = self._padded_expert_offsets_buf
|
||||
# Already filled in _allocate_buffers with [0, 128, 256, ...]
|
||||
padded_expert_offsets.zero_()
|
||||
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
||||
total_padded_slots = padded_expert_offsets[self.num_experts]
|
||||
|
||||
# Gather hidden states into slot order, then pad to 128 per expert
|
||||
# -- Gather hidden states into slot order, pad to 128 per expert --
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
padded_size = self.num_experts * 128
|
||||
padded_hidden = self._padded_hidden_buf[:padded_size]
|
||||
padded_hidden = self._padded_hidden_buf[:total_padded_slots]
|
||||
padded_hidden.zero_()
|
||||
padded_hidden[:num_slots] = slot_hidden
|
||||
# Scatter real tokens into padded positions
|
||||
# Each expert e: tokens are at [expert_offsets[e], expert_offsets[e+1])
|
||||
# In padded buffer: tokens are at [padded_expert_offsets[e], padded_expert_offsets[e]+tokens_per_expert[e])
|
||||
row_indices = self._row_indices_buf[:num_slots]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_hidden[padded_dst] = slot_hidden
|
||||
|
||||
# === L1: gate + up ===
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
@@ -410,8 +407,7 @@ class CuTeDSLMoERunner:
|
||||
)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
x_sf[:num_slots], # Only real tokens have scale data
|
||||
expert_offsets[:self.num_experts + 1],
|
||||
x_sf, expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
|
||||
@@ -423,27 +419,25 @@ class CuTeDSLMoERunner:
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# Only keep real token outputs (padding rows are garbage)
|
||||
l1_out = l1_out[:num_slots]
|
||||
# Extract real token outputs from padded GEMM output
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
|
||||
# === SiLU(gate) * up ===
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
gate = l1_out_real[:, :self.intermediate_size]
|
||||
up = l1_out_real[:, self.intermediate_size:]
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
|
||||
# === L2: down ===
|
||||
# Pad activated to 128 tokens per expert for L2 GEMM
|
||||
padded_activated = self._padded_activated_buf[:padded_size]
|
||||
padded_activated = self._padded_activated_buf[:total_padded_slots]
|
||||
padded_activated.zero_()
|
||||
padded_activated[:num_slots] = activated
|
||||
padded_activated[padded_dst] = activated
|
||||
|
||||
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
|
||||
padded_activated, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
l2_x_sf[:num_slots],
|
||||
expert_offsets[:self.num_experts + 1],
|
||||
l2_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
|
||||
@@ -455,13 +449,12 @@ class CuTeDSLMoERunner:
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
# Only keep real token outputs
|
||||
l2_out = l2_out[:num_slots]
|
||||
l2_out_real = l2_out[padded_dst]
|
||||
|
||||
# === Scatter -> final output ===
|
||||
y = self._output_buf[:num_tokens]
|
||||
y.zero_()
|
||||
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
|
||||
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
|
||||
y.scatter_add_(
|
||||
0,
|
||||
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
|
||||
|
||||
Reference in New Issue
Block a user