FIX Bug 26: quantize slot tokens, not padded buffer

The runner was quantizing padded_hidden (4096 rows) and then taking
x_sf[:num_slots] (the first 48 rows). That slice only contained the scales
for the leading 48 rows of the padded buffer (expert 0), not the scales for
the tokens scattered across the other padded positions (expert 1 at row 128, etc.).

Fix: quantize slot_hidden (sorted tokens, num_slots rows) to get the
correct per-token x_sf, then scatter the resulting x_fp4 into a padded FP4
buffer for the GEMM. The scale assembly now receives the correct x_sf.

Added hidden_fp4 and activated_fp4 padded buffers for FP4 scatter.
This commit is contained in:
2026-05-17 21:24:43 +00:00
parent 4d0b6d889d
commit 7256070dd3

View File

@@ -166,9 +166,15 @@ class CuTeDSLMoERunner:
'hidden': torch.zeros(
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
'hidden_fp4': torch.zeros(
padded_max_slots, self.hidden_size // 2, dtype=torch.float4_e2m1fn_x2, device=self.device
),
'activated': torch.zeros(
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
),
'activated_fp4': torch.zeros(
padded_max_slots, self.intermediate_size // 2, dtype=torch.float4_e2m1fn_x2, device=self.device
),
})
self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key]
@@ -406,35 +412,40 @@ class CuTeDSLMoERunner:
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
total_padded_slots = padded_expert_offsets[self.num_experts]
# -- Gather hidden states into slot order, scatter into padded layout --
# Each expert's tokens go at [padded_expert_offsets[e], padded_expert_offsets[e] + tokens_per_expert[e])
# Padding rows between tokens_per_expert and padded_tokens_per_expert are zero.
# -- Gather hidden states into slot order, compute padded_dst --
slot_hidden = hidden_states[sorted_token_ids]
padded_hidden = self._shared_bufs['hidden']
padded_hidden.zero_()
# scatter: padded_hidden[padded_expert_offsets[expert_assign] + local_row] = slot_hidden
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
padded_hidden[padded_dst] = slot_hidden
# === L1: gate + up ===
x_fp4, x_sf = quantize_activation_nvfp4(
padded_hidden, self._l1_activation_global_scale
# Quantize slot_hidden (sorted tokens), NOT padded_hidden.
# padded_hidden is padded with zeros; quantizing it produces
# x_sf rows at padded positions, but x_sf[:num_slots] would
# only get scales for the first num_slots PADDED rows (expert 0),
# not the scattered token positions. Quantizing slot_hidden
# gives x_sf with num_slots rows (one per token), which the
# scale assembly correctly scatters into padded layout.
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.zero_()
padded_x_fp4[padded_dst] = slot_x_fp4
l1_scale_a = self._assemble_scales_cudagraph_safe(
x_sf[:num_slots], expert_offsets[:self.num_experts + 1],
slot_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
l1_out = run_nvfp4_grouped_gemm(
mat_a=x_fp4, mat_b=self._l1_mat_b,
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
@@ -454,23 +465,23 @@ class CuTeDSLMoERunner:
activated = gate_silu * up
# === L2: down ===
padded_activated = self._shared_bufs['activated']
padded_activated.zero_()
padded_activated[padded_dst] = activated
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
padded_activated, self._l2_activation_global_scale
# Quantize activated (per-token), scatter into padded FP4 buffer
slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
activated, self._l2_activation_global_scale
)
padded_activated_fp4 = self._shared_bufs['activated_fp4']
padded_activated_fp4.zero_()
padded_activated_fp4[padded_dst] = slot_l2_x_fp4
l2_scale_a = self._assemble_scales_cudagraph_safe(
l2_x_sf[:num_slots], expert_offsets[:self.num_experts + 1],
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
l2_out = run_nvfp4_grouped_gemm(
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,