Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,12 +8,16 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
create_flashinfer_prepare_finalize)
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -24,7 +28,6 @@ class FlashinferMoeBackend(Enum):
|
||||
|
||||
|
||||
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||
|
||||
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
|
||||
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
|
||||
# with the necessary kernels is released.
|
||||
@@ -44,13 +47,16 @@ def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.reshape(-1, 2, x.shape[-2] // 2,
|
||||
x.shape[-1]).flip(dims=[1]).reshape(x.shape)
|
||||
return (
|
||||
x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
|
||||
)
|
||||
|
||||
|
||||
def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor):
|
||||
def rotate_flashinfer_fp8_moe_weights(
|
||||
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
|
||||
):
|
||||
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
|
||||
|
||||
epilogue_tile_m = 128
|
||||
num_experts = gemm1_weights.shape[0]
|
||||
hidden_size = gemm1_weights.shape[-1]
|
||||
@@ -60,13 +66,13 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
|
||||
gemm1_weights_fp8_interleaved = []
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp8_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(gemm1_weights[i]))
|
||||
reorder_rows_for_gated_act_gemm(gemm1_weights[i])
|
||||
)
|
||||
|
||||
# Stack weights and scales for all experts
|
||||
gemm1_weights_fp8_interleaved = torch.stack(
|
||||
gemm1_weights_fp8_interleaved).reshape(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size)
|
||||
gemm1_weights_fp8_interleaved = torch.stack(gemm1_weights_fp8_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size
|
||||
)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_fp8_shuffled = []
|
||||
@@ -74,18 +80,21 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp8_shuffled.append(
|
||||
shuffle_matrix_a(
|
||||
gemm1_weights_fp8_interleaved[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
|
||||
gemm2_weights_fp8_shuffled.append(
|
||||
shuffle_matrix_a(gemm2_weights[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
|
||||
# Stack weights for all experts
|
||||
gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
|
||||
torch.float8_e4m3fn)
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
|
||||
torch.float8_e4m3fn)
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
def apply_flashinfer_per_tensor_scale_fp8(
|
||||
@@ -102,16 +111,22 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_scalar to be initialized")
|
||||
"Expected output1_scales_scalar to be initialized"
|
||||
)
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_gate_scalar to be initialized")
|
||||
"Expected output1_scales_gate_scalar to be initialized"
|
||||
)
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output2_scales_scalar to be initialized")
|
||||
"Expected output2_scales_scalar to be initialized"
|
||||
)
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
)
|
||||
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
@@ -140,37 +155,39 @@ def get_moe_scaling_factors(
|
||||
activation_scale: torch.Tensor,
|
||||
gemm2_weights_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (
|
||||
1.0 / activation_scale)
|
||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
|
||||
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
|
||||
output2_scales_scalar = activation_scale * gemm2_weights_scale
|
||||
|
||||
return output1_scales_scalar, output1_scales_gate_scalar, \
|
||||
output2_scales_scalar
|
||||
return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
|
||||
|
||||
|
||||
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
output1_scales, output1_gate_scales, output2_scales = \
|
||||
get_moe_scaling_factors(
|
||||
layer.w13_input_scale, layer.w13_weight_scale,
|
||||
layer.w2_input_scale, layer.w2_weight_scale
|
||||
)
|
||||
output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
|
||||
layer.w13_input_scale,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_input_scale,
|
||||
layer.w2_weight_scale,
|
||||
)
|
||||
layer.register_parameter(
|
||||
'output1_scales_scalar',
|
||||
torch.nn.Parameter(output1_scales, requires_grad=False))
|
||||
"output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
|
||||
)
|
||||
layer.register_parameter(
|
||||
'output1_scales_gate_scalar',
|
||||
torch.nn.Parameter(output1_gate_scales, requires_grad=False))
|
||||
"output1_scales_gate_scalar",
|
||||
torch.nn.Parameter(output1_gate_scales, requires_grad=False),
|
||||
)
|
||||
layer.register_parameter(
|
||||
'output2_scales_scalar',
|
||||
torch.nn.Parameter(output2_scales, requires_grad=False))
|
||||
"output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
|
||||
)
|
||||
layer.register_parameter(
|
||||
'w2_input_scale_inv',
|
||||
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
|
||||
"w2_input_scale_inv",
|
||||
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
|
||||
moe: Optional[FusedMoEConfig],
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return create_flashinfer_prepare_finalize(use_dp)
|
||||
@@ -193,8 +210,7 @@ def select_cutlass_fp8_gemm_impl(
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
)
|
||||
|
||||
assert out_dtype is not None, (
|
||||
"If moe config is None, out_dtype must be passed")
|
||||
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
|
||||
return FlashInferExperts(
|
||||
out_dtype=out_dtype,
|
||||
quant_config=quant_config,
|
||||
@@ -217,9 +233,10 @@ def flashinfer_cutlass_moe_fp8(
|
||||
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||
select_cutlass_fp8_gemm_impl(moe=None,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype))
|
||||
select_cutlass_fp8_gemm_impl(
|
||||
moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
@@ -245,4 +262,5 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
f" expected one of {allowed_backends}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user