[UX] Reduce DeepGEMM warmup log output to single progress bar (#30903)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -10,7 +10,7 @@ import torch
 from tqdm import tqdm
 
 import vllm.envs as envs
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
+def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
+    """Get the M values to warmup for a given weight tensor."""
+    n, _ = w.size()
+    device = w.device
+
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        return _generate_optimal_warmup_m_values(max_tokens, n, device)
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        return list(range(1, max_tokens + 1))
+
+
+def _deepgemm_fp8_gemm_nt_warmup(
+    w: torch.Tensor,
+    ws: torch.Tensor,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
     if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
         return
 
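Note (illustrative, not part of the patch): factoring the M-value selection into _get_fp8_gemm_nt_m_values means the same list can be computed both to size the shared progress bar (see _count_warmup_iterations further down) and to drive the warmup loop, so the bar total matches the number of pbar.update(1) calls. A minimal sketch of the "full" path, assuming only that max_tokens is a positive integer:

    # Sketch: with VLLM_DEEP_GEMM_WARMUP="full", the helper returns every M from
    # 1 to max_tokens, so one weight shape contributes max_tokens steps to the
    # progress-bar total.
    max_tokens = 8
    full_m_values = list(range(1, max_tokens + 1))
    assert full_m_values == [1, 2, 3, 4, 5, 6, 7, 8]
    assert len(full_m_values) == max_tokens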
@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
-    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
-    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
-        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
-    else:
-        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
-            "Expected "
-            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
-            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
-        )
-        m_values = list(range(1, max_tokens + 1))
-        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
-
-    pbar = tqdm(total=len(m_values), desc=desc)
+    m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)
 
     for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
-        pbar.update(1)
+        if pbar is not None:
+            pbar.update(1)
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
 
@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
 
 
-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+def _get_grouped_gemm_params(
     w1: torch.Tensor,
     w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     num_topk: int,
     max_tokens: int,
-):
-    if (
-        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
-    ):
-        return
-
+) -> tuple[int, int, torch.Tensor]:
     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
 
     block_m = get_mk_alignment_for_contiguous_layout()[0]
@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
     )
     expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
 
+    return MAX_M, block_m, expert_ids
+
+
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    num_topk: int,
+    max_tokens: int,
+    pbar: tqdm | None = None,
+):
+    if (
+        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    ):
+        return
+
+    MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
+    device = w1.device
+
     def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
         _, n, k = w.size()
         a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
-        # Generate M values in block_m increments (already optimized for MoE)
         m_values = list(range(block_m, MAX_M + 1, block_m))
 
-        pbar = tqdm(
-            total=len(m_values),
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
-            f"[{len(m_values)} values, block_m={block_m}]",
-        )
-
         for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
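For the grouped path the warmup M values are the block-aligned sizes range(block_m, MAX_M + 1, block_m); the closed-form count used later in _count_warmup_iterations, (MAX_M - block_m) // block_m + 1, is exactly the length of that range whenever MAX_M >= block_m. A quick self-contained check with illustrative values (not taken from the patch):

    # Sketch: the closed-form count equals len(range(block_m, MAX_M + 1, block_m)).
    block_m, max_m = 128, 1024  # example values; real ones come from the model config
    m_values = list(range(block_m, max_m + 1, block_m))
    assert (max_m - block_m) // block_m + 1 == len(m_values)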
@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
                 out[:num_tokens],
                 expert_ids[:num_tokens],
             )
-            pbar.update(1)
+            if pbar is not None:
+                pbar.update(1)
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
 
 
-def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
+def deepgemm_fp8_gemm_nt_warmup(
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
+):
     dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]
 
     for dgm in dg_modules:
         w, ws, _ = _extract_data_from_linear_base_module(dgm)
-        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
+        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)
 
 
 def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-    model: torch.nn.Module, max_tokens: int
+    model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
 ):
     dg_modules = [
         m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
             dgm
         )
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk, max_tokens
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
         )
 
 
+def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
+    seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
+    seen_grouped_sizes: set[torch.Size] = set(
+        GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
+    )
+
+    total = 0
+    for m in model.modules():
+        if _fp8_linear_may_use_deep_gemm(m):
+            w, _, _ = _extract_data_from_linear_base_module(m)
+            if w.size() not in seen_fp8_sizes:
+                total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
+                seen_fp8_sizes.add(w.size())
+        elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
+            w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
+            if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
+                continue
+            MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
+            n_values = (MAX_M - block_m) // block_m + 1
+            if w13.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w13.size())
+            if w2.size() not in seen_grouped_sizes:
+                total += n_values
+                seen_grouped_sizes.add(w2.size())
+    return total
+
+
 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
-    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
+    total = _count_warmup_iterations(model, max_tokens)
+    if total == 0:
+        return
+
+    # Only show progress bar on rank 0 to avoid cluttered output
+    if is_global_first_rank():
+        with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
+            deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
+            deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
+    else:
+        deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
+        deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
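Overall, the patch replaces the per-shape progress bars with one bar per warmup run: deep_gemm_warmup precomputes the total with _count_warmup_iterations, opens a single tqdm bar only on the global first rank, and threads it through both warmup paths, which call pbar.update(1) only when a bar was provided. A minimal self-contained sketch of that shared-bar pattern (toy helper names, not the vLLM functions):

    from tqdm import tqdm

    def _path_a(steps: int, pbar: tqdm | None = None) -> None:
        for _ in range(steps):
            # real warmup kernels would run here
            if pbar is not None:
                pbar.update(1)

    def _path_b(steps: int, pbar: tqdm | None = None) -> None:
        for _ in range(steps):
            if pbar is not None:
                pbar.update(1)

    def warmup(steps_a: int, steps_b: int, show_progress: bool) -> None:
        total = steps_a + steps_b  # counted up front, like _count_warmup_iterations
        if total == 0:
            return
        if show_progress:  # stands in for is_global_first_rank()
            with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
                _path_a(steps_a, pbar)
                _path_b(steps_b, pbar)
        else:
            _path_a(steps_a, None)
            _path_b(steps_b, None)

    warmup(steps_a=4, steps_b=6, show_progress=True)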