diff --git a/single_shot_inference.py b/single_shot_inference.py index 36866850..636824cb 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -884,6 +884,15 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 6. Production FMHA — B1 mixed FP8/BF16 decode path. _pt('fmha_start') + if li == 0: + print(f" L0 B1 verify: kv_nope_fp8 dtype={kv_nope_fp8.dtype} shape={tuple(kv_nope_fp8.shape)} " + f"kv_nope_scale dtype={kv_nope_scale.dtype} shape={tuple(kv_nope_scale.shape)} " + f"kv_rope_bf16 dtype={kv_rope_bf16.dtype} shape={tuple(kv_rope_bf16.shape)}", flush=True) + assert kv_nope_fp8.dtype in (torch.uint8, torch.float8_e4m3fn), f"kv_nope_fp8 wrong dtype: {kv_nope_fp8.dtype}" + assert kv_nope_scale.dtype == torch.float32, f"kv_nope_scale wrong dtype: {kv_nope_scale.dtype}" + assert kv_rope_bf16.dtype == torch.bfloat16, f"kv_rope_bf16 wrong dtype: {kv_rope_bf16.dtype}" + assert kv_nope_fp8.shape[-1] == nope_dim, f"kv_nope_fp8 dim={kv_nope_fp8.shape[-1]} != nope_dim={nope_dim}" + assert kv_rope_bf16.shape[-1] == rd, f"kv_rope_bf16 dim={kv_rope_bf16.shape[-1]} != rope_dim={rd}" if VERBOSE >= 2 and li < 3: print(f" L{li} FMHA mixed input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True) attn_out = _run_production_fmha_mixed(