Bump Flashinfer to v0.6.1 (#30993)

Author: elvischenv <219235043+elvischenv@users.noreply.github.com>
Date: 2026-01-22 00:49:50 +08:00
Committed by: GitHub
Commit: 808d6fd7b9
Parent: 1861ae8aae

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
12 changed files with 20 additions and 73 deletions


@@ -10,9 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
     RoutingMethodType,
 )
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    calculate_tile_tokens_dim,
-)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
@@ -167,7 +164,6 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         use_shuffled_weight=False,
     )
@@ -255,9 +251,6 @@ def fi_trtllm_fp8_per_tensor_moe(
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling_factor,
         use_routing_scales_on_input=use_routing_scales_on_input,
-        tile_tokens_dim=calculate_tile_tokens_dim(
-            hidden_states.shape[0], top_k, num_experts
-        ),
         routing_method_type=routing_method_type,
     )
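
Note: both hunks above stop passing tile_tokens_dim; the kwarg is gone from the FlashInfer v0.6.1 kernel signatures, which presumably pick the CTA tile size internally now. A hypothetical compatibility shim, shown here only as a sketch (call_moe_kernel is not part of this commit), could tolerate either FlashInfer version by probing the kernel's signature:

    import inspect

    def call_moe_kernel(kernel, **kwargs):
        # Hypothetical shim: drop tile_tokens_dim before calling a
        # FlashInfer >= 0.6.1 kernel whose signature no longer has it.
        try:
            params = inspect.signature(kernel).parameters
        except ValueError:  # native kernels may not expose a signature
            params = {}
        if "tile_tokens_dim" not in params:
            kwargs.pop("tile_tokens_dim", None)
        return kernel(**kwargs)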


@@ -160,7 +160,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "local_expert_offset": local_expert_offset,
             "local_num_experts": local_num_experts,
             "routed_scaling_factor": None,
-            "tile_tokens_dim": None,
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,


@@ -982,8 +982,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             self.intermediate_size,  # padded to multiple of 256
             layer.ep_rank * layer.local_num_experts,  # local_expert_offset
             self.num_experts,  # local num experts
-            None,
-            None,
+            None,  # routed_scaling_factor
             1 if layer.renormalize else 0,  # routing_method_type, renormalize
             True,  # do finalize
             tune_max_num_tokens=max(self.max_capture_size, 1),


@@ -392,7 +392,6 @@ def flashinfer_trtllm_fp4_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
-        tile_tokens_dim=None,
         routing_method_type=routing_method_type,
         do_finalize=True,
     )[0]
@@ -478,7 +477,6 @@ def flashinfer_trtllm_fp4_routed_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         routed_scaling_factor=None,
-        tile_tokens_dim=None,
         routing_method_type=1,
         do_finalize=True,
     )[0]


@@ -25,30 +25,6 @@ class FlashinferMoeBackend(Enum):
     CUTEDSL = "CUTEDSL"
 
-def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
-    from flashinfer import next_positive_power_of_2
-
-    # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
-    # TODO: Revert this to dynamic calculation once a new version of FlashInfer
-    # with the necessary kernels is released.
-    tile_tokens_dim = 8
-
-    # A factor considering tokens are not perfectly balanced among experts.
-    imbalance_factor = 1.3
-    # Calculate the number of tokens per expert
-    # assuming perfect distribution.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # Apply the imbalance factor.
-    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-max_tile_tokens_dim tokens per CTA tile
-    # as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
 
 def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
     return (
         x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
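
Note: the deleted heuristic above is easy to reproduce standalone. A minimal sketch, with next_positive_power_of_2 re-implemented locally for illustration (FlashInfer ships its own):

    def next_positive_power_of_2(x: int) -> int:
        return 1 if x <= 0 else 1 << (x - 1).bit_length()

    def tile_tokens_dim_sketch(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Tokens per expert under perfect balance, inflated by a 1.3x
        # imbalance factor, padded to a power of 2, clamped to [8, 64].
        per_expert = int((num_tokens * top_k) // num_experts * 1.3)
        return min(max(next_positive_power_of_2(per_expert), 8), 64)

    # e.g. 256 tokens, top_k=8, 64 experts: 32 tokens per expert, 41 after
    # the imbalance factor, padded to 64, which is already the cap.
    assert tile_tokens_dim_sketch(256, 8, 64) == 64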


@@ -416,7 +416,7 @@ class TRTLLMPrefill:
     max_q_len: int
     """
-    The maximum query length *among prefill requests*.
+    The maximum query length *among prefill requests*.
     """
     max_seq_len: int
max_seq_len: int
@@ -1051,6 +1051,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.prefill_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
@@ -1099,6 +1100,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.decode_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
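
Note: both plan() calls above now pass o_data_type explicitly, set to the model dtype, presumably so the planned output dtype can differ from the q/kv dtypes (e.g. an FP8 KV cache producing bf16 output). fast_plan_decode accepts str, torch.dtype, or None for these parameters; a hypothetical normalization helper (not vLLM or FlashInfer API) for that annotation might look like:

    import torch

    def canonicalize_dtype(
        dtype: str | torch.dtype | None, default: torch.dtype
    ) -> torch.dtype:
        # Hypothetical helper: resolve a str / torch.dtype / None value,
        # as accepted by the *_data_type parameters, to a concrete dtype.
        if dtype is None:
            return default
        if isinstance(dtype, torch.dtype):
            return dtype
        return getattr(torch, dtype)  # e.g. "float16" -> torch.float16

    assert canonicalize_dtype("bfloat16", torch.float16) is torch.bfloat16
    assert canonicalize_dtype(None, torch.float16) is torch.float16
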
@@ -1568,6 +1570,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
     q_data_type: str | torch.dtype | None = "float16",
     kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
     data_type: str | torch.dtype | None = None,
     sm_scale: float | None = None,
     rope_scale: float | None = None,
@@ -1606,6 +1609,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1663,7 +1667,7 @@ def fast_plan_decode(
     try:
         # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -1680,9 +1684,13 @@ def fast_plan_decode(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
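
Note: the tail of the tensor-core plan() call is now backend-dependent; only the fa2 backend appends fixed_split_size, disable_split_kv, and a trailing 0 for num_colocated_ctas. A standalone sketch of the pattern (build_plan_args is hypothetical; the names mirror the diff):

    def build_plan_args(
        common_args: list,
        backend: str,
        fixed_split_size: int,
        disable_split_kv: bool,
    ) -> list:
        # fa2 kernels take three extra trailing plan arguments; other
        # backends (e.g. trtllm-gen) keep the shorter argument list.
        args = list(common_args)
        if backend == "fa2":
            args += [fixed_split_size, disable_split_kv, 0]  # num_colocated_ctas = 0
        return args

    # plan_info = cached_module.plan(*build_plan_args(common, "fa2", 4096, False))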