Add B1 weight/format verification at L0 in single_shot

This commit is contained in:
2026-06-03 01:52:55 +00:00
parent 8df5de5477
commit 6bca7f66f3

View File

@@ -884,6 +884,15 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 6. Production FMHA — B1 mixed FP8/BF16 decode path.
_pt('fmha_start')
if li == 0:
print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} "
f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} "
f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True)
assert kv_nope_fp8.dtype in (torch.uint8, torch.float8_e4m3fn), f"kv_nope_fp8 wrong dtype: {kv_nope_fp8.dtype}"
assert kv_nope_scale.dtype == torch.float32, f"kv_nope_scale wrong dtype: {kv_nope_scale.dtype}"
assert kv_rope_bf16.dtype == torch.bfloat16, f"kv_rope_bf16 wrong dtype: {kv_rope_bf16.dtype}"
assert kv_nope_fp8.shape[-1] == nope_dim, f"kv_nope_fp8 dim={kv_nope_fp8.shape[-1]} != nope_dim={nope_dim}"
assert kv_rope_bf16.shape[-1] == rd, f"kv_rope_bf16 dim={kv_rope_bf16.shape[-1]} != rope_dim={rd}"
if VERBOSE >= 2 and li < 3:
print(f" L{li} FMHA mixed input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
attn_out = _run_production_fmha_mixed(