Report each check's PASS/FAIL verdict through a precomputed `status` variable.

The third check's original f-string had an unterminated 'FAIL literal. The
first two checks are changed the same way for consistency: the verdict has to
be interpolated via {...}; written as bare str(...) inside the f-string it is
printed literally instead of being evaluated.

NOTE(review): line breaks and context-line indentation below were
reconstructed from a whitespace-collapsed copy of this patch, and the
post-image blob id on the index line was not recomputed. Regenerate with
`git diff`, or apply with `git apply --ignore-whitespace`, if the context
does not match the working tree.

diff --git a/tests/test_vllm_codepaths_b200.py b/tests/test_vllm_codepaths_b200.py
index 48cf0234..711553f1 100644
--- a/tests/test_vllm_codepaths_b200.py
+++ b/tests/test_vllm_codepaths_b200.py
@@ -164,7 +164,8 @@ def main():
     )
     c = F.cosine_similarity(q_rope_ref.flatten().unsqueeze(0).float(),
                             q_test.flatten().unsqueeze(0).float()).item()
-    print(f" fused_qnorm_rope vs manual: cosine = {c:.6f} {'PASS' if c>=0.999 else 'FAIL'}")
+    status = "PASS" if c >= 0.999 else "FAIL"
+    print(f" fused_qnorm_rope vs manual: cosine = {c:.6f} {status}")
 
     # ── Test 2: Verify blackwell_attention_kv_write ───────────
     print("\n=== Test 2: blackwell_attention_kv_write ===")
@@ -195,7 +196,8 @@ def main():
     kv_dequant = kv_dequantize_fp8(kv_read, inv_read)
     c = F.cosine_similarity(kv_rope_manual.flatten().unsqueeze(0).float(),
                             kv_dequant.flatten().unsqueeze(0).float()).item()
-    print(f" vllm_kv_write roundtrip: cosine = {c:.6f} {'PASS' if c>=0.99 else 'FAIL'}")
+    status = "PASS" if c >= 0.99 else "FAIL"
+    print(f" vllm_kv_write roundtrip: cosine = {c:.6f} {status}")
 
     # ── Test 3: Decode attention using swa_indices ────────────
     print("\n=== Test 3: Decode attention with swa_indices ===")
@@ -242,7 +244,8 @@ def main():
     o_ref_decode = o_ref[-1:]
     c = F.cosine_similarity(o_decode.flatten().unsqueeze(0).float(),
                             o_ref_decode.flatten().unsqueeze(0).float()).item()
-    print(f" Decode vs reference cosine: {c:.6f} {'PASS' if c>=0.98 else 'FAIL}")
+    status = "PASS" if c >= 0.98 else "FAIL"
+    print(f" Decode vs reference cosine: {c:.6f} {status}")
 
     print("\n=== DONE ===")
 