more fixes6
This commit is contained in:
@@ -1576,15 +1576,12 @@ class DeepseekV4Model(nn.Module):
|
||||
# wo_a: attention forward reads .weight and .weight_scale_inv directly
|
||||
# for fp8_einsum. Only layer that needs FP8 conversion.
|
||||
fp8_proj_names = {"wo_a"}
|
||||
# Attention layers called via .forward() — need bf16
|
||||
# cuBLAS BF16 is broken on Blackwell — nothing gets dequantized to BF16.
|
||||
# Everything stays native NVFP4/FP8 via FlashInfer CUTLASS.
|
||||
bf16_proj_names = set()
|
||||
bf16_shared_names = set()
|
||||
# No BF16 dequant paths active — cuBLAS BF16 is broken on Blackwell.
|
||||
# wo_a goes NVFP4→FP8; compressor gets reconstructed from checkpoint;
|
||||
# MoE experts stay native NVFP4 via CUTLASS kernel.
|
||||
|
||||
fp8_converted = 0
|
||||
fp8_from_bf16 = 0
|
||||
bf16_converted = 0
|
||||
compressor_converted = 0
|
||||
|
||||
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
|
||||
@@ -1608,16 +1605,6 @@ class DeepseekV4Model(nn.Module):
|
||||
self._convert_bf16_to_fp8(mod, FP8_MAX)
|
||||
fp8_from_bf16 += 1
|
||||
|
||||
# BF16 conversion: attention layers via .forward()
|
||||
for proj_name in bf16_proj_names:
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
|
||||
# Compressor: fused_wkv_wgate used via direct torch.mm
|
||||
# Compressor weights were SKIPPED during loading (skip patterns)
|
||||
# because the stacking weight_loader corrupts NVFP4 uint8 data.
|
||||
@@ -1639,25 +1626,12 @@ class DeepseekV4Model(nn.Module):
|
||||
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
|
||||
|
||||
# Shared experts: dequantize NVFP4 → BF16
|
||||
ffn = layer.ffn
|
||||
if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None:
|
||||
for proj_name in bf16_shared_names:
|
||||
if not hasattr(ffn.shared_experts, proj_name):
|
||||
continue
|
||||
mod = getattr(ffn.shared_experts, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
|
||||
total_fp8 = fp8_converted + fp8_from_bf16
|
||||
total_bf16 = bf16_converted + compressor_converted
|
||||
total_bf16 = compressor_converted
|
||||
if int(os.environ.get('NVFP4_DEBUG', '0')) and (total_fp8 > 0 or total_bf16 > 0):
|
||||
print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, "
|
||||
f"{fp8_from_bf16} BF16->FP8, "
|
||||
f"{bf16_converted} attn/shared->BF16, "
|
||||
f"{compressor_converted} compressor->BF16")
|
||||
|
||||
|
||||
@@ -1914,7 +1888,6 @@ class DeepseekV4Model(nn.Module):
|
||||
if hasattr(fused_mod, attr):
|
||||
delattr(fused_mod, attr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def _convert_bf16_to_fp8(self, mod, fp8_max):
|
||||
"""Convert BF16 weight to FP8 for fp8_einsum path.
|
||||
|
||||
@@ -150,8 +150,11 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
|
||||
# Pack E2M1 pairs into single bytes (2 per byte, low nibble first)
|
||||
# mxf4nvf4 reads FP4 packed from SMEM — must match kernel's TMA layout
|
||||
e2m1_lo = e2m1_4bit[0::2] # even indices → low nibble
|
||||
e2m1_hi = e2m1_4bit[1::2] # odd indices → high nibble
|
||||
# Reshape to pairs instead of strided indexing (Triton doesn't support
|
||||
# [0::2] on reshaped tensors — unsupported tensor index error)
|
||||
e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
|
||||
e2m1_lo = e2m1_pairs[:, 0] # even indices → low nibble
|
||||
e2m1_hi = e2m1_pairs[:, 1] # odd indices → high nibble
|
||||
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
|
||||
|
||||
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
|
||||
|
||||
Reference in New Issue
Block a user