[UX] Speedup DeepGEMM warmup with heuristics (#25619)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
@@ -26,6 +26,55 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
|
||||
|
||||
|
||||
def _generate_optimal_warmup_m_values(
    max_tokens: int, n: int, device: torch.device
) -> list[int]:
    """
    Generate M values that cover all possible DeepGEMM kernel configurations.
    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp

    The result is the union of three groups of M values:

    * a fixed set of small sizes (1, 2, 4 and every multiple of 8 up to 64),
      each clamped to ``max_tokens``;
    * for every ``(block_m, block_n)`` pair and each of the first 10 waves,
      ``wave * num_sms * block_m // ceil_div(n, block_n)``, kept only when it
      lies in ``[1, max_tokens]``;
    * every multiple of each ``block_m`` that does not exceed ``max_tokens``.

    Args:
        max_tokens: Maximum number of tokens to warmup for
        n: The actual N dimension from the weight tensor
        device: The torch device to get properties from.

    Returns:
        Sorted list of unique M values, each within ``[1, max_tokens]``.
        Empty when ``max_tokens`` is smaller than 1.
    """
    # Nothing to warm up; return before touching the device.
    if max_tokens < 1:
        return []

    # DeepGEMM's possible block sizes
    block_ms = (64, 128, 256)
    # block_n candidates are multiples of 16, capped at 256 and at n.
    block_ns = range(16, min(256, n) + 1, 16)
    # Number of N-direction blocks per block_n; independent of block_m, so
    # computed once. Empty when n < 16.
    n_block_counts = [(n + block_n - 1) // block_n for block_n in block_ns]

    num_sms = torch.cuda.get_device_properties(device).multi_processor_count

    # Always include small cases. Clamp them so that no returned value exceeds
    # max_tokens: callers slice buffers of length max_tokens, so a larger M
    # would only repeat the max_tokens-sized run.
    small_cases = [1, 2, 4, *range(8, 65, 8)]
    m_values = {min(m, max_tokens) for m in small_cases}

    # Collect M values where different wave patterns occur
    for block_m in block_ms:
        for n_blocks in n_block_counts:
            # Add key M boundaries for this block combination
            for wave in range(1, 11):  # Up to 10 waves
                # M where this block config transitions to next wave
                m = wave * num_sms * block_m // n_blocks
                if 1 <= m <= max_tokens:
                    m_values.add(m)

        # Add block_m boundaries
        m_values.update(range(block_m, max_tokens + 1, block_m))

    return sorted(m_values)
|
||||
|
||||
|
||||
def _extract_data_from_linear_base_module(
|
||||
m: torch.nn.Module,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
|
||||
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
|
||||
)
|
||||
out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
|
||||
|
||||
pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
|
||||
num_tokens = max_tokens
|
||||
while num_tokens > 0:
|
||||
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
|
||||
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
|
||||
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
|
||||
m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
|
||||
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
|
||||
else:
|
||||
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
|
||||
"Expected "
|
||||
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
|
||||
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
|
||||
)
|
||||
m_values = list(range(1, max_tokens + 1))
|
||||
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
|
||||
|
||||
pbar = tqdm(total=len(m_values), desc=desc)
|
||||
|
||||
for num_tokens in m_values:
|
||||
fp8_gemm_nt(
|
||||
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
|
||||
)
|
||||
pbar.update(1)
|
||||
num_tokens -= 1
|
||||
|
||||
FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
|
||||
|
||||
@@ -195,12 +257,16 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
|
||||
)
|
||||
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
|
||||
|
||||
# Generate M values in block_m increments (already optimized for MoE)
|
||||
m_values = list(range(block_m, MAX_M + 1, block_m))
|
||||
|
||||
pbar = tqdm(
|
||||
total=MAX_BLOCKS,
|
||||
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
|
||||
total=len(m_values),
|
||||
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
|
||||
f"[{len(m_values)} values, block_m={block_m}]",
|
||||
)
|
||||
num_tokens = MAX_M
|
||||
while num_tokens > 0:
|
||||
|
||||
for num_tokens in m_values:
|
||||
m_grouped_fp8_gemm_nt_contiguous(
|
||||
(a1q[:num_tokens], a1q_scales[:num_tokens]),
|
||||
(w, w_scale),
|
||||
@@ -208,7 +274,6 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
|
||||
expert_ids[:num_tokens],
|
||||
)
|
||||
pbar.update(1)
|
||||
num_tokens = num_tokens - block_m
|
||||
|
||||
for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
|
||||
if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
|
||||
|
||||
@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
|
||||
do_deep_gemm_warmup = (
|
||||
envs.VLLM_USE_DEEP_GEMM
|
||||
and is_deep_gemm_supported()
|
||||
and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
|
||||
and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
|
||||
)
|
||||
if do_deep_gemm_warmup:
|
||||
model = worker.get_model()
|
||||
|
||||
Reference in New Issue
Block a user