Bump Flashinfer to v0.6.1 (#30993)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -416,7 +416,7 @@ class TRTLLMPrefill:
|
||||
|
||||
max_q_len: int
|
||||
"""
|
||||
The maximum query length *among prefill requests*.
|
||||
The maximum query length *among prefill requests*.
|
||||
"""
|
||||
|
||||
max_seq_len: int
|
||||
@@ -1051,6 +1051,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
o_data_type=self.model_config.dtype,
|
||||
fixed_split_size=self.prefill_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
@@ -1099,6 +1100,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.kv_cache_dtype,
|
||||
o_data_type=self.model_config.dtype,
|
||||
fixed_split_size=self.decode_fixed_split_size,
|
||||
disable_split_kv=self.disable_split_kv,
|
||||
)
|
||||
@@ -1568,6 +1570,7 @@ def fast_plan_decode(
|
||||
logits_soft_cap: float | None = None,
|
||||
q_data_type: str | torch.dtype | None = "float16",
|
||||
kv_data_type: str | torch.dtype | None = None,
|
||||
o_data_type: str | torch.dtype | None = None,
|
||||
data_type: str | torch.dtype | None = None,
|
||||
sm_scale: float | None = None,
|
||||
rope_scale: float | None = None,
|
||||
@@ -1606,6 +1609,7 @@ def fast_plan_decode(
|
||||
logits_soft_cap,
|
||||
q_data_type,
|
||||
kv_data_type,
|
||||
o_data_type,
|
||||
data_type,
|
||||
sm_scale,
|
||||
rope_scale,
|
||||
@@ -1663,7 +1667,7 @@ def fast_plan_decode(
|
||||
|
||||
try:
|
||||
# Make sure we pass exactly 19 arguments for tensor core version
|
||||
self._plan_info = self._cached_module.plan(
|
||||
args = [
|
||||
self._float_workspace_buffer,
|
||||
self._int_workspace_buffer,
|
||||
self._pin_memory_int_workspace_buffer,
|
||||
@@ -1680,9 +1684,13 @@ def fast_plan_decode(
|
||||
head_dim,
|
||||
False, # causal
|
||||
window_left,
|
||||
fixed_split_size,
|
||||
disable_split_kv,
|
||||
0,
|
||||
]
|
||||
if self._backend == "fa2":
|
||||
args.append(fixed_split_size)
|
||||
args.append(disable_split_kv)
|
||||
args.append(0) # num_colocated_ctas
|
||||
self._plan_info = self._cached_module.plan(
|
||||
*args,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in tensor core plan: {e}") from e
|
||||
|
||||
Reference in New Issue
Block a user