[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.q_data_type = self.kv_cache_dtype
         else:
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
-                                 (self.num_qo_heads // self.num_kv_heads > 4))
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
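For reference, a minimal sketch (not vLLM's exact code) of the decision this hunk deletes: tensor cores used to be requested only when forced via the VLLM_FLASHINFER_FORCE_TENSOR_CORES environment variable or when the grouped-query ratio num_qo_heads / num_kv_heads exceeded 4.

    def should_use_tensor_cores(num_qo_heads: int,
                                num_kv_heads: int,
                                force: bool = False) -> bool:
        # Old behaviour: env-var override, otherwise a GQA-ratio threshold.
        return force or (num_qo_heads // num_kv_heads > 4)

After this commit the predicate (and the env var behind it) is no longer needed; the decode wrapper is always built with use_tensor_cores=True.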
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=self.use_tensor_cores)
+                # Tensor cores are enabled by default because the perf would be
+                # at least as good as CUDA cores for all attention ops on the
+                # latest GPUs.
+                use_tensor_cores=True,
+            )
 
         # save the decode wrapper
         if use_cudagraph:
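To illustrate what the builder now does, here is a hedged sketch of constructing a FlashInfer decode wrapper with tensor cores on; the workspace size and the "NHD" kv-cache layout below are illustrative assumptions, not vLLM's exact values.

    import torch
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper

    # Illustrative workspace size; vLLM allocates and reuses its own buffer.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace,
        "NHD",                  # kv-cache layout (assumed here)
        use_tensor_cores=True,  # always enabled after this change
    )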
@@ -984,52 +985,29 @@ def fast_plan_decode(
         self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                                non_blocking=True)
 
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
-        try:
-            # Make sure we pass exactly 15 arguments for tensor core version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_host,
-                indptr_cpu,
-                seq_lens_cpu,
-                batch_size,  # total_num_rows
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                head_dim,
-                head_dim,
-                False,  # causal
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in tensor core plan: {e}") from e
-    else:
-        try:
-            # Make sure we pass exactly 15 arguments for standard version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                indptr_cpu,
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                window_left,
-                logits_soft_cap,
-                head_dim,
-                head_dim,
-                torch.empty(0, dtype=q_data_type),
-                torch.empty(0, dtype=kv_data_type),
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in standard plan: {e}") from e
+    try:
+        # Make sure we pass exactly 15 arguments for tensor core version
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_cpu,
+            seq_lens_cpu,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in tensor core plan: {e}") from e
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
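A note on the simplified fast_plan_decode path: the tensor-core plan takes a qo_indptr, and during decode every request contributes exactly one query token, so that indptr is just 0..batch_size. A hedged illustration of the idea behind the _get_range_buf call (not flashinfer's actual implementation, which caches the buffer):

    import torch

    def qo_indptr_for_decode(batch_size: int) -> torch.Tensor:
        # One query token per request, so the cumulative offsets are an arange.
        return torch.arange(batch_size + 1, dtype=torch.int32)

    print(qo_indptr_for_decode(4))  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)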