Update fp4 quantize API (#21327)
Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
@@ -181,12 +181,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
g2_alphas,
|
||||
]
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
hidden_states,
|
||||
topk_ids.to(torch.int),
|
||||
topk_weights,
|
||||
input=hidden_states,
|
||||
token_selected_experts=topk_ids.to(torch.int),
|
||||
token_final_scales=topk_weights,
|
||||
# The FlashInfer API requires nvfp4 weights to be passed as torch.long views
|
||||
w1.view(torch.long),
|
||||
w2.view(torch.long),
|
||||
fc1_expert_weights=w1.view(torch.long),
|
||||
fc2_expert_weights=w2.view(torch.long),
|
||||
output_dtype=out_dtype,
|
||||
quant_scales=quant_scales,
|
||||
input_sf=a1q_scale,
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
extract_required_args, moe_kernel_quantize_input)
|
||||
from vllm.utils.flashinfer import fp4_swizzle_blockscale
|
||||
from vllm.utils.flashinfer import block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes(local_tokens):
|
||||
@@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dim=0,
|
||||
sizes=get_local_sizes(local_tokens))
|
||||
a1_m, a1_n = a1q.shape
|
||||
a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)
|
||||
a1q_scale = block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
|
||||
@@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
|
||||
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
|
||||
"cutlass_fused_moe")
|
||||
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
|
||||
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
|
||||
"fp4_swizzle_blockscale")
|
||||
block_scale_interleave = _lazy_import_wrapper("flashinfer",
|
||||
"block_scale_interleave")
|
||||
|
||||
# Special case for autotune since it returns a context manager
|
||||
autotune = _lazy_import_wrapper(
|
||||
@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
|
||||
required_functions = [
|
||||
("flashinfer.fused_moe", "cutlass_fused_moe"),
|
||||
("flashinfer", "fp4_quantize"),
|
||||
("flashinfer", "fp4_swizzle_blockscale"),
|
||||
("flashinfer", "block_scale_interleave"),
|
||||
]
|
||||
|
||||
for module_name, attr_name in required_functions:
|
||||
@@ -110,7 +110,7 @@ __all__ = [
|
||||
"flashinfer_trtllm_fp8_block_scale_moe",
|
||||
"flashinfer_cutlass_fused_moe",
|
||||
"fp4_quantize",
|
||||
"fp4_swizzle_blockscale",
|
||||
"block_scale_interleave",
|
||||
"autotune",
|
||||
"has_flashinfer_moe",
|
||||
"has_flashinfer_cutlass_fused_moe",
|
||||
|
||||
Reference in New Issue
Block a user