more fixes6

This commit is contained in:
2026-05-14 20:08:25 +00:00
parent 40b980b9d6
commit 4363eee2ce
2 changed files with 9 additions and 33 deletions

View File

@@ -1576,15 +1576,12 @@ class DeepseekV4Model(nn.Module):
# wo_a: attention forward reads .weight and .weight_scale_inv directly
# for fp8_einsum. Only layer that needs FP8 conversion.
fp8_proj_names = {"wo_a"}
# Attention layers called via .forward() — need bf16
# cuBLAS BF16 is broken on Blackwell — nothing gets dequantized to BF16.
# Everything stays native NVFP4/FP8 via FlashInfer CUTLASS.
bf16_proj_names = set()
bf16_shared_names = set()
# No BF16 dequant paths active — cuBLAS BF16 is broken on Blackwell.
# wo_a goes NVFP4→FP8; compressor gets reconstructed from checkpoint;
# MoE experts stay native NVFP4 via CUTLASS kernel.
fp8_converted = 0
fp8_from_bf16 = 0
bf16_converted = 0
compressor_converted = 0
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
@@ -1608,16 +1605,6 @@ class DeepseekV4Model(nn.Module):
self._convert_bf16_to_fp8(mod, FP8_MAX)
fp8_from_bf16 += 1
# BF16 conversion: attention layers via .forward()
for proj_name in bf16_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
# Compressor: fused_wkv_wgate used via direct torch.mm
# Compressor weights were SKIPPED during loading (skip patterns)
# because the stacking weight_loader corrupts NVFP4 uint8 data.
@@ -1639,25 +1626,12 @@ class DeepseekV4Model(nn.Module):
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
# Shared experts: dequantize NVFP4 → BF16
ffn = layer.ffn
if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None:
for proj_name in bf16_shared_names:
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
total_fp8 = fp8_converted + fp8_from_bf16
total_bf16 = bf16_converted + compressor_converted
total_bf16 = compressor_converted
if int(os.environ.get('NVFP4_DEBUG', '0')) and (total_fp8 > 0 or total_bf16 > 0):
print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, "
f"{fp8_from_bf16} BF16->FP8, "
f"{bf16_converted} attn/shared->BF16, "
f"{compressor_converted} compressor->BF16")
@@ -1914,7 +1888,6 @@ class DeepseekV4Model(nn.Module):
if hasattr(fused_mod, attr):
delattr(fused_mod, attr)
return 1
return 0
def _convert_bf16_to_fp8(self, mod, fp8_max):
"""Convert BF16 weight to FP8 for fp8_einsum path.

View File

@@ -150,8 +150,11 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
# Pack E2M1 pairs into single bytes (2 per byte, low nibble first)
# mxf4nvf4 reads FP4 packed from SMEM — must match kernel's TMA layout
e2m1_lo = e2m1_4bit[0::2] # even indices → low nibble
e2m1_hi = e2m1_4bit[1::2] # odd indices → high nibble
# Reshape to pairs instead of strided indexing (Triton doesn't support
# [0::2] on reshaped tensors — unsupported tensor index error)
e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
e2m1_lo = e2m1_pairs[:, 0] # even indices → low nibble
e2m1_hi = e2m1_pairs[:, 1] # odd indices → high nibble
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)