[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer backend. Fixes issue #32353 (#32417)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -1385,8 +1385,11 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
|
||||||
# prefill_query may be non-contiguous
|
# prefill_query may be non-contiguous or have degenerate strides
|
||||||
prefill_query = prefill_query.contiguous()
|
# First ensure memory contiguity, then fix degenerate strides
|
||||||
|
# with reshape. contiguous() alone doesn't fix degenerate
|
||||||
|
# strides when a dimension has size 1.
|
||||||
|
prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
|
||||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_prefill = attn_metadata.prefill.block_tables
|
block_tables_prefill = attn_metadata.prefill.block_tables
|
||||||
seq_lens_prefill = attn_metadata.prefill.seq_lens
|
seq_lens_prefill = attn_metadata.prefill.seq_lens
|
||||||
@@ -1495,9 +1498,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
out=output[:num_decode_tokens],
|
out=output[:num_decode_tokens],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# decode_query may be non-contiguous
|
# decode_query may be non-contiguous or have degenerate strides
|
||||||
assert isinstance(attn_metadata.decode, TRTLLMDecode)
|
assert isinstance(attn_metadata.decode, TRTLLMDecode)
|
||||||
decode_query = decode_query.contiguous()
|
# First ensure memory contiguity, then fix degenerate strides
|
||||||
|
# with reshape. contiguous() alone doesn't fix degenerate
|
||||||
|
# strides when a dimension has size 1.
|
||||||
|
decode_query = decode_query.contiguous().reshape(decode_query.shape)
|
||||||
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
workspace_buffer = _get_trtllm_gen_workspace_buffer()
|
||||||
block_tables_decode = attn_metadata.decode.block_tables
|
block_tables_decode = attn_metadata.decode.block_tables
|
||||||
seq_lens_decode = attn_metadata.decode.seq_lens
|
seq_lens_decode = attn_metadata.decode.seq_lens
|
||||||
|
|||||||
Reference in New Issue
Block a user