[Misc] DeepGEMM: Avoid JIT generation in the hot-path (#22215)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath
2025-08-08 19:09:59 -04:00
committed by GitHub
parent cd9b9de1fb
commit f703b923f3
5 changed files with 274 additions and 37 deletions
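The hunks excerpted below only show the typing/annotation side of the change; the commit title's point is that DeepGEMM's JIT kernel generation should happen during startup/warmup rather than on the first request that reaches the hot path. A rough sketch of that warmup idea, using made-up helper names that are not part of this commit:

import torch

def warmup_deep_gemm_kernels(gemm_fn, weights: list[torch.Tensor],
                             max_num_tokens: int, hidden_size: int) -> None:
    # Run the DeepGEMM-backed matmul once per weight with representative
    # shapes; the first call per shape is what triggers JIT generation, so
    # later calls in the serving hot path hit already-compiled kernels.
    dummy_input = torch.randn(max_num_tokens, hidden_size,
                              device="cuda", dtype=torch.bfloat16)
    for w in weights:
        gemm_fn(dummy_input, w)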


@@ -4,6 +4,9 @@
import functools
import json
import os
+# torch.compile needs typing.List. It will fail torch.library.infer_schema
+# otherwise
+from typing import List # noqa: UP035
from typing import Any, Callable, Optional
import torch
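The comment in the hunk above is the reason for the otherwise redundant typing.List import: the custom-op machinery in torch.library derives an operator schema from the Python type hints, and per that comment it accepts typing.List but rejects the builtin list[...] spelling. A minimal, hypothetical illustration of the same pattern (the op name and body are made up; only the annotation style matters):

from typing import List, Optional  # noqa: UP035

import torch

# torch.library.custom_op infers the op's schema from these annotations,
# so the List/Optional spellings must be ones the schema parser accepts.
@torch.library.custom_op("demo::toy_op", mutates_args=())
def toy_op(x: torch.Tensor,
           block_shape: Optional[List[int]] = None) -> torch.Tensor:  # noqa: UP006
    # Toy body; a real op would dispatch on block_shape here.
    return x.clone()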
@@ -998,29 +1001,30 @@ def get_config_dtype_str(
return None
-def inplace_fused_experts(hidden_states: torch.Tensor,
-w1: torch.Tensor,
-w2: torch.Tensor,
-topk_weights: torch.Tensor,
-topk_ids: torch.Tensor,
-activation: str = "silu",
-is_act_and_mul: bool = True,
-apply_router_weight_on_input: bool = False,
-use_fp8_w8a8: bool = False,
-use_int8_w8a8: bool = False,
-use_int8_w8a16: bool = False,
-use_int4_w4a16: bool = False,
-use_mxfp4_w4a4: bool = False,
-per_channel_quant: bool = False,
-global_num_experts: int = -1,
-expert_map: Optional[torch.Tensor] = None,
-w1_scale: Optional[torch.Tensor] = None,
-w2_scale: Optional[torch.Tensor] = None,
-w1_zp: Optional[torch.Tensor] = None,
-w2_zp: Optional[torch.Tensor] = None,
-a1_scale: Optional[torch.Tensor] = None,
-a2_scale: Optional[torch.Tensor] = None,
-block_shape: Optional[list[int]] = None) -> None:
+def inplace_fused_experts(
+hidden_states: torch.Tensor,
+w1: torch.Tensor,
+w2: torch.Tensor,
+topk_weights: torch.Tensor,
+topk_ids: torch.Tensor,
+activation: str = "silu",
+is_act_and_mul: bool = True,
+apply_router_weight_on_input: bool = False,
+use_fp8_w8a8: bool = False,
+use_int8_w8a8: bool = False,
+use_int8_w8a16: bool = False,
+use_int4_w4a16: bool = False,
+use_mxfp4_w4a4: bool = False,
+per_channel_quant: bool = False,
+global_num_experts: int = -1,
+expert_map: Optional[torch.Tensor] = None,
+w1_scale: Optional[torch.Tensor] = None,
+w2_scale: Optional[torch.Tensor] = None,
+w1_zp: Optional[torch.Tensor] = None,
+w2_zp: Optional[torch.Tensor] = None,
+a1_scale: Optional[torch.Tensor] = None,
+a2_scale: Optional[torch.Tensor] = None,
+block_shape: Optional[List[int]] = None) -> None: #noqa: UP006
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, is_act_and_mul,
apply_router_weight_on_input, use_fp8_w8a8,
@@ -1082,7 +1086,7 @@ def flashinfer_fused_moe_blockscale_fp8(
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
-block_shape: list[int],
+block_shape: List[int], #noqa: UP006
routed_scaling: float = 1.0) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
assert top_k <= global_num_experts
@@ -1264,7 +1268,8 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
-block_shape: Optional[list[int]] = None) -> torch.Tensor:
+block_shape: Optional[List[int]] = None, #noqa: UP006
+) -> torch.Tensor:
return fused_experts_impl(
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
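Since these wrapper signatures are what the schema inference mentioned in the comment above works from, every annotation in them has to stay within what the schema parser understands. A quick way to sanity-check an annotation style, assuming torch.library.infer_schema is available as a public API (PyTorch 2.5+); the dummy function below only mirrors the annotation style, it is not the real wrapper:

from typing import List, Optional  # noqa: UP035

import torch
from torch.library import infer_schema

def dummy(hidden_states: torch.Tensor,
          block_shape: Optional[List[int]] = None) -> None:  # noqa: UP006
    pass

# Prints a schema string along the lines of
# "(Tensor(a0!) hidden_states, int[]? block_shape=None) -> ()".
# Per the comment added in this commit, the builtin list[int] spelling
# would make the inference fail instead.
print(infer_schema(dummy, mutates_args=["hidden_states"]))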