[Misc] DeepGemmExperts : Avoid JIT generation in the hot-path (#21955)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
57393715e8
commit
a65f46be5e
@@ -4,7 +4,9 @@ import functools
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as env
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
@@ -17,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils import has_deep_gemm, run_once
|
||||
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -82,6 +84,65 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
return True
|
||||
|
||||
|
||||
@run_once
def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
                                          w1_scale: torch.Tensor,
                                          w2_scale: torch.Tensor,
                                          num_topk: int) -> None:
    """
    Pre-compile the DeepGemm contiguous grouped-gemm kernels for w1 and w2.

    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
    input tensor shapes. In this function, we launch the grouped gemm with
    every M (number of rows) from MAX_M down towards 0 in steps of block_m,
    for both weight tensors, so all the kernels are JIT'ed and cached.

    Note that this warmup is expected to happen during the model profile
    call and not during actual model inference. The function is decorated
    with `run_once`, so only the first call is expected to do any work.

    Args:
        w1: First expert weight tensor; unpacked as (num_experts, n, k).
        w2: Second expert weight tensor; must have the same size(0) as w1.
        w1_scale: Scales passed to the kernel together with w1.
        w2_scale: Scales passed to the kernel together with w2.
        num_topk: Number of experts selected per token; used to size MAX_M.

    Raises:
        AssertionError: If w1 and w2 disagree on the number of experts.
    """

    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")

    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)

    # Assign a randomly chosen expert to every block of block_m rows. The
    # values are irrelevant for the warmup; they only need to be valid
    # expert indices that are constant within each block.
    MAX_BLOCKS = MAX_M // block_m
    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor) -> None:
        """Run the grouped gemm against `w` for every warmup M size."""
        _, n, k = w.size()

        # Allocate directly in fp8. Going through a default-dtype buffer and
        # casting with `.to()` would transiently hold a much larger
        # (MAX_M, k) tensor, and the contents are never read as meaningful
        # data anyway.
        a1q = torch.empty((MAX_M, k),
                          device=device,
                          dtype=torch.float8_e4m3fn)
        # NOTE(review): the K-dimension scale granularity reuses block_m;
        # this presumes the DeepGemm block shape is square - confirm.
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        # The context manager closes the progress bar even if a kernel
        # launch raises.
        with tqdm(total=MAX_BLOCKS,
                  desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})"
                  ) as pbar:
            # Largest problem first, shrinking by one block per iteration.
            for num_tokens in range(MAX_M, 0, -block_m):
                m_grouped_fp8_gemm_nt_contiguous(
                    (a1q[:num_tokens], a1q_scales[:num_tokens]),
                    (w, w_scale), out[:num_tokens],
                    expert_ids[:num_tokens])
                pbar.update(1)

    _warmup(w1, w1_scale)
    _warmup(w2, w2_scale)
|
||||
|
||||
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self):
|
||||
@@ -156,6 +217,20 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
):
|
||||
assert self.block_shape is not None
|
||||
assert a1q_scale is not None
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
if not env.VLLM_SKIP_DEEP_GEMM_WARMUP:
|
||||
# DeepGemm JITs the grouped-gemm kernels. We don't want the JIT'ing
|
||||
# to happen during actual model-inference. The
|
||||
# `warmup_deepgemm_gg_contiguous_kernels` function is a `run_once` decorated
|
||||
# function that executes during the model profile run. This warmup
|
||||
# should create all the required JITs for the current model.
|
||||
warmup_deepgemm_gg_contiguous_kernels(w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
num_topk=topk_ids.size(1))
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
|
||||
Reference in New Issue
Block a user