Fix NVFP4 attention: slice output to actual N after 128-padding
This commit is contained in:
@@ -219,7 +219,8 @@ class NVFP4Attention:
|
||||
self._cache_key = cache_key
|
||||
|
||||
# Run Q×K^T GEMM
|
||||
scores = self._runner.run(q_2d) # (T*NH, T)
|
||||
scores = self._runner.run(q_2d) # (T*NH, N_padded)
|
||||
scores = scores[:, :T] # Slice to actual N=T (runner pads to 128)
|
||||
scores = scores * scale
|
||||
|
||||
# Causal mask
|
||||
|
||||
Reference in New Issue
Block a user