FIX Bug 26: quantize slot tokens, not padded buffer
The runner was quantizing the padded_hidden (4096 rows) and then taking x_sf[:num_slots] (first 48 rows). This only got scales for expert 0 (the first 48 rows of the padded buffer), not the scales for tokens scattered across padded positions (expert 1 at row 128, etc). Fix: quantize slot_hidden (sorted tokens, num_slots rows) to get correct per-token x_sf, then scatter x_fp4 into padded FP4 buffer for the GEMM. The scale assembly now receives the correct x_sf. Added hidden_fp4 and activated_fp4 padded buffers for FP4 scatter.
This commit is contained in:
@@ -166,9 +166,15 @@ class CuTeDSLMoERunner:
|
||||
'hidden': torch.zeros(
|
||||
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
'hidden_fp4': torch.zeros(
|
||||
padded_max_slots, self.hidden_size // 2, dtype=torch.float4_e2m1fn_x2, device=self.device
|
||||
),
|
||||
'activated': torch.zeros(
|
||||
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
'activated_fp4': torch.zeros(
|
||||
padded_max_slots, self.intermediate_size // 2, dtype=torch.float4_e2m1fn_x2, device=self.device
|
||||
),
|
||||
})
|
||||
self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key]
|
||||
|
||||
@@ -406,35 +412,40 @@ class CuTeDSLMoERunner:
|
||||
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
||||
total_padded_slots = padded_expert_offsets[self.num_experts]
|
||||
|
||||
# -- Gather hidden states into slot order, scatter into padded layout --
|
||||
# Each expert's tokens go at [padded_expert_offsets[e], padded_expert_offsets[e] + tokens_per_expert[e])
|
||||
# Padding rows between tokens_per_expert and padded_tokens_per_expert are zero.
|
||||
# -- Gather hidden states into slot order, compute padded_dst --
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
padded_hidden = self._shared_bufs['hidden']
|
||||
padded_hidden.zero_()
|
||||
# scatter: padded_hidden[padded_expert_offsets[expert_assign] + local_row] = slot_hidden
|
||||
row_indices = self._row_indices_buf[:num_slots]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_hidden[padded_dst] = slot_hidden
|
||||
|
||||
# === L1: gate + up ===
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
padded_hidden, self._l1_activation_global_scale
|
||||
# Quantize slot_hidden (sorted tokens), NOT padded_hidden.
|
||||
# padded_hidden is padded with zeros; quantizing it produces
|
||||
# x_sf rows at padded positions, but x_sf[:num_slots] would
|
||||
# only get scales for the first num_slots PADDED rows (expert 0),
|
||||
# not the scattered token positions. Quantizing slot_hidden
|
||||
# gives x_sf with num_slots rows (one per token), which the
|
||||
# scale assembly correctly scatters into padded layout.
|
||||
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
# Scatter x_fp4 into padded layout for the GEMM
|
||||
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
||||
padded_x_fp4.zero_()
|
||||
padded_x_fp4[padded_dst] = slot_x_fp4
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
x_sf[:num_slots], expert_offsets[:self.num_experts + 1],
|
||||
slot_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
|
||||
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=self._l1_mat_b,
|
||||
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
@@ -454,23 +465,23 @@ class CuTeDSLMoERunner:
|
||||
activated = gate_silu * up
|
||||
|
||||
# === L2: down ===
|
||||
padded_activated = self._shared_bufs['activated']
|
||||
padded_activated.zero_()
|
||||
padded_activated[padded_dst] = activated
|
||||
|
||||
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
|
||||
padded_activated, self._l2_activation_global_scale
|
||||
# Quantize activated (per-token), scatter into padded FP4 buffer
|
||||
slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
|
||||
activated, self._l2_activation_global_scale
|
||||
)
|
||||
padded_activated_fp4 = self._shared_bufs['activated_fp4']
|
||||
padded_activated_fp4.zero_()
|
||||
padded_activated_fp4[padded_dst] = slot_l2_x_fp4
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
l2_x_sf[:num_slots], expert_offsets[:self.num_experts + 1],
|
||||
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
|
||||
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
|
||||
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
|
||||
Reference in New Issue
Block a user