Fix NVFP4 attention: slice output to actual N after 128-padding

This commit is contained in:
2026-05-19 08:55:31 +00:00
parent 42285b6c24
commit c54ddbdae1

View File

@@ -219,7 +219,8 @@ class NVFP4Attention:
self._cache_key = cache_key
# Run Q×K^T GEMM
scores = self._runner.run(q_2d) # (T*NH, T)
scores = self._runner.run(q_2d) # (T*NH, N_padded)
scores = scores[:, :T] # Slice to actual N=T (runner pads to 128)
scores = scores * scale
# Causal mask