diff --git a/tests/test_nvfp4_attn_gemm_b200.py b/tests/test_nvfp4_attn_gemm_b200.py index 6673a607..ebfac828 100644 --- a/tests/test_nvfp4_attn_gemm_b200.py +++ b/tests/test_nvfp4_attn_gemm_b200.py @@ -219,7 +219,8 @@ class NVFP4Attention: self._cache_key = cache_key # Run Q×K^T GEMM - scores = self._runner.run(q_2d) # (T*NH, T) + scores = self._runner.run(q_2d) # (T*NH, N_padded) + scores = scores[:, :T] # Slice to actual N=T (runner pads to 128) scores = scores * scale # Causal mask