[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer backend. Fixes issue #32353 (#32417)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -1385,8 +1385,11 @@ class FlashInferImpl(AttentionImpl):
 )
 else:
 assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
-# prefill_query may be non-contiguous
-prefill_query = prefill_query.contiguous()
+# prefill_query may be non-contiguous or have degenerate strides
+# First ensure memory contiguity, then fix degenerate strides
+# with reshape. contiguous() alone doesn't fix degenerate
+# strides when a dimension has size 1.
+prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
 workspace_buffer = _get_trtllm_gen_workspace_buffer()
 block_tables_prefill = attn_metadata.prefill.block_tables
 seq_lens_prefill = attn_metadata.prefill.seq_lens
@@ -1495,9 +1498,12 @@ class FlashInferImpl(AttentionImpl):
 out=output[:num_decode_tokens],
 )
 else:
-# decode_query may be non-contiguous
+# decode_query may be non-contiguous or have degenerate strides
 assert isinstance(attn_metadata.decode, TRTLLMDecode)
-decode_query = decode_query.contiguous()
+# First ensure memory contiguity, then fix degenerate strides
+# with reshape. contiguous() alone doesn't fix degenerate
+# strides when a dimension has size 1.
+decode_query = decode_query.contiguous().reshape(decode_query.shape)
 workspace_buffer = _get_trtllm_gen_workspace_buffer()
 block_tables_decode = attn_metadata.decode.block_tables
 seq_lens_decode = attn_metadata.decode.seq_lens
|
||||
Reference in New Issue
Block a user