[ModelOpt] Introduce VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE env var to control blockscale tensor allocation (#18160)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
""" CUTLASS based Fused MoE kernels."""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
MAX_TOKENS_PER_EXPERT = int(
|
||||
os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
|
||||
|
||||
|
||||
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
|
||||
== m), ("topk must be provided for each row of a")
|
||||
assert (m <= MAX_TOKENS_PER_EXPERT), (
|
||||
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m = {m}. Use"
|
||||
f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
expert_map=a_map,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
expert_map=a_map)
|
||||
|
||||
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
|
||||
w1_blockscale, w1_alphas, problem_sizes1,
|
||||
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
intermediate,
|
||||
a2_gscale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
|
||||
|
||||
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
|
||||
w2_alphas, problem_sizes2, expert_offsets[:-1],
|
||||
|
||||
Reference in New Issue
Block a user