fix: per-expert scale assembly (match assemble_scales_2d_side)
This commit is contained in:
@@ -78,8 +78,6 @@ class CuTeDSLMoERunner:
|
||||
max_slots = self.max_num_tokens * self.top_k
|
||||
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
|
||||
# Worst case: 1 token per expert, each padded to 128 rows
|
||||
max_padded_rows = self.num_experts * 128
|
||||
|
||||
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
|
||||
self._token_indices = torch.arange(
|
||||
@@ -91,12 +89,13 @@ class CuTeDSLMoERunner:
|
||||
self._expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._padded_scales_buf = torch.zeros(
|
||||
max_padded_rows, padded_cols, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
# Per-expert scale buffers: each expert gets a 128-row block
|
||||
# This matches assemble_scales_2d_side which pads+swizzles each expert independently
|
||||
self._per_expert_scale_bufs = [
|
||||
torch.zeros(128, padded_cols, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
@@ -145,46 +144,40 @@ class CuTeDSLMoERunner:
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
|
||||
|
||||
Uses GPU-computed indices to scatter scale data into padded positions,
|
||||
then applies the swizzle. Returns 2D tensor.
|
||||
Matches the working assemble_scales_2d_side: pads each expert's scales
|
||||
to 128 rows, swizzles each expert block independently, then concatenates.
|
||||
No .item(), no .tolist(), no Python control flow on GPU data.
|
||||
|
||||
Fixed-shape: each expert gets exactly 128 rows (padded). We always
|
||||
copy the full 128-row block from x_sf (zero-padded rows are harmless).
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
K_sf = x_sf.shape[1]
|
||||
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
|
||||
|
||||
# Compute tokens per expert (GPU)
|
||||
tokens_per_expert = expert_offsets[1:] - expert_offsets[:-1]
|
||||
# For each expert: zero the buffer, scatter its rows, swizzle, flatten
|
||||
swizzled_parts = []
|
||||
for e in range(num_experts):
|
||||
buf = self._per_expert_scale_bufs[e]
|
||||
buf.zero_()
|
||||
|
||||
start = expert_offsets[e]
|
||||
# Always copy 128 rows — extra rows will be zeros from x_sf padding
|
||||
# or from the zero-initialized buffer
|
||||
# Use a fixed-shape slice: buf is always (128, padded_cols)
|
||||
# x_sf may not have 128 rows for this expert, but that's fine —
|
||||
# the buffer is zero-initialized and we overwrite with whatever exists
|
||||
buf[:, :K_sf] = x_sf[start:start + 128]
|
||||
|
||||
# Swizzle this expert's block (matches pad_and_swizzle_single per expert)
|
||||
swizzled = pad_and_swizzle_single(buf)
|
||||
swizzled_parts.append(swizzled)
|
||||
|
||||
# Compute padded rows per expert (round up to 128)
|
||||
padded_rows_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
||||
# Concatenate all expert blocks (matches cat_byte_reinterpretable_tensors)
|
||||
# float8_e4m3fn is a 1-byte float type — cat via uint8 view
|
||||
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
|
||||
all_flat = all_flat.view(torch.float8_e4m3fn)
|
||||
|
||||
# Compute padded offsets
|
||||
padded_expert_offsets = self._padded_expert_offsets_buf
|
||||
padded_expert_offsets.zero_()
|
||||
padded_expert_offsets[1:] = padded_rows_per_expert.cumsum(0)
|
||||
|
||||
# Use the FULL pre-allocated scales buffer (no GPU scalar slicing)
|
||||
padded_scales = self._padded_scales_buf
|
||||
padded_scales.zero_()
|
||||
|
||||
# Build index mapping: for each row in x_sf, which expert does it belong to?
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = self._token_indices[:total_rows]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=False
|
||||
).clamp(max=num_experts - 1)
|
||||
|
||||
# Destination row in padded buffer
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
|
||||
# Scatter x_sf into padded_scales
|
||||
padded_scales[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Apply swizzle, reshape to 2D (element count preserved by swizzle)
|
||||
swizzled = pad_and_swizzle_single(padded_scales)
|
||||
return swizzled.reshape(padded_scales.shape[0], -1)
|
||||
return all_flat.reshape(num_experts * 128, -1)
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
"""Run the NVFP4 MoE forward pass.
|
||||
|
||||
Reference in New Issue
Block a user