Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions


@@ -8,12 +8,16 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
-                                                         FusedMoEQuantConfig)
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
+    FusedMoEQuantConfig,
+)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
-    FlashInferExperts)
+    FlashInferExperts,
+)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    create_flashinfer_prepare_finalize)
+    create_flashinfer_prepare_finalize,
+)
 
 logger = init_logger(__name__)
@@ -24,7 +28,6 @@ class FlashinferMoeBackend(Enum):
 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
     # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
     # TODO: Revert this to dynamic calculation once a new version of FlashInfer
     # with the necessary kernels is released.
@@ -44,13 +47,16 @@ def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
 def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
-    return x.reshape(-1, 2, x.shape[-2] // 2,
-                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+    return (
+        x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+    )
 
 
-def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
-                                      gemm2_weights: torch.Tensor):
+def rotate_flashinfer_fp8_moe_weights(
+    gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
+):
     from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
 
     epilogue_tile_m = 128
     num_experts = gemm1_weights.shape[0]
     hidden_size = gemm1_weights.shape[-1]
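
Aside (separate from the diff): the swap_w13_to_w31 hunk above is formatting-only; the yapf and ruff versions are the same expression. A minimal standalone sketch of what that expression does, assuming only that torch is installed and using a tiny made-up tensor:

import torch

def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # ruff-formatted body copied from the hunk above
    return (
        x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape)
    )

# Build a [1, 8, 8] tensor whose first half along dim -2 is zeros and second half is ones.
w13 = torch.cat([torch.zeros(1, 4, 8), torch.ones(1, 4, 8)], dim=-2)
w31 = swap_w13_to_w31(w13)
# The two halves along dim -2 are exchanged; everything else is untouched.
assert torch.equal(w31[:, :4], torch.ones(1, 4, 8))
assert torch.equal(w31[:, 4:], torch.zeros(1, 4, 8))
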
@@ -60,13 +66,13 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
     gemm1_weights_fp8_interleaved = []
     for i in range(num_experts):
         gemm1_weights_fp8_interleaved.append(
-            reorder_rows_for_gated_act_gemm(gemm1_weights[i]))
+            reorder_rows_for_gated_act_gemm(gemm1_weights[i])
+        )
 
     # Stack weights and scales for all experts
-    gemm1_weights_fp8_interleaved = torch.stack(
-        gemm1_weights_fp8_interleaved).reshape(num_experts,
-                                               2 * intermediate_size,
-                                               hidden_size)
+    gemm1_weights_fp8_interleaved = torch.stack(gemm1_weights_fp8_interleaved).reshape(
+        num_experts, 2 * intermediate_size, hidden_size
+    )
 
     # Shuffle weights and scaling factors for transposed mma output
     gemm1_weights_fp8_shuffled = []
@@ -74,18 +80,21 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
     for i in range(num_experts):
         gemm1_weights_fp8_shuffled.append(
             shuffle_matrix_a(
-                gemm1_weights_fp8_interleaved[i].view(torch.uint8),
-                epilogue_tile_m))
+                gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m
+            )
+        )
         gemm2_weights_fp8_shuffled.append(
-            shuffle_matrix_a(gemm2_weights[i].view(torch.uint8),
-                             epilogue_tile_m))
+            shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m)
+        )
 
     # Stack weights for all experts
     gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
-        torch.float8_e4m3fn)
+        torch.float8_e4m3fn
+    )
     gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
-        torch.float8_e4m3fn)
+        torch.float8_e4m3fn
+    )
 
 
 def apply_flashinfer_per_tensor_scale_fp8(
@@ -102,16 +111,22 @@ def apply_flashinfer_per_tensor_scale_fp8(
     from flashinfer.fused_moe import RoutingMethodType
 
     import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
 
     assert layer.output1_scales_scalar is not None, (
-        "Expected output1_scales_scalar to be initialized")
+        "Expected output1_scales_scalar to be initialized"
+    )
     assert layer.output1_scales_scalar is not None, (
-        "Expected output1_scales_gate_scalar to be initialized")
+        "Expected output1_scales_gate_scalar to be initialized"
+    )
     assert layer.output1_scales_scalar is not None, (
-        "Expected output2_scales_scalar to be initialized")
+        "Expected output2_scales_scalar to be initialized"
+    )
 
     from vllm.model_executor.models.llama4 import Llama4MoE
 
-    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
+    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
         "FusedMoE flashinfer kernels are only supported for Llama4"
+    )
     return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
         routing_logits=router_logits,
         routing_bias=routing_bias,
@@ -140,37 +155,39 @@ def get_moe_scaling_factors(
     activation_scale: torch.Tensor,
     gemm2_weights_scale: torch.Tensor,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    output1_scales_scalar = gemm1_weights_scale * input_scale * (
-        1.0 / activation_scale)
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
     output1_scales_gate_scalar = gemm1_weights_scale * input_scale
     output2_scales_scalar = activation_scale * gemm2_weights_scale
-    return output1_scales_scalar, output1_scales_gate_scalar, \
-        output2_scales_scalar
+    return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
 
 
 def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
-    output1_scales, output1_gate_scales, output2_scales = \
-        get_moe_scaling_factors(
-            layer.w13_input_scale, layer.w13_weight_scale,
-            layer.w2_input_scale, layer.w2_weight_scale
-        )
+    output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
+        layer.w13_input_scale,
+        layer.w13_weight_scale,
+        layer.w2_input_scale,
+        layer.w2_weight_scale,
+    )
     layer.register_parameter(
-        'output1_scales_scalar',
-        torch.nn.Parameter(output1_scales, requires_grad=False))
+        "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
+    )
     layer.register_parameter(
-        'output1_scales_gate_scalar',
-        torch.nn.Parameter(output1_gate_scales, requires_grad=False))
+        "output1_scales_gate_scalar",
+        torch.nn.Parameter(output1_gate_scales, requires_grad=False),
+    )
     layer.register_parameter(
-        'output2_scales_scalar',
-        torch.nn.Parameter(output2_scales, requires_grad=False))
+        "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
+    )
     layer.register_parameter(
-        'w2_input_scale_inv',
-        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
+        "w2_input_scale_inv",
+        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
+    )
 
 
 def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
-        moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
+    moe: Optional[FusedMoEConfig],
+) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
     return create_flashinfer_prepare_finalize(use_dp)
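
Aside (separate from the diff): the get_moe_scaling_factors hunk above only rewraps the arithmetic; the formulas are unchanged. A standalone sketch with made-up scalar scales, assuming only torch, showing what each factor works out to:

import torch

# Hypothetical per-tensor scales, chosen only to make the arithmetic easy to follow.
input_scale = torch.tensor(0.5)
gemm1_weights_scale = torch.tensor(2.0)
activation_scale = torch.tensor(0.25)
gemm2_weights_scale = torch.tensor(4.0)

# Same expressions as in the diff above.
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale

assert output1_scales_scalar.item() == 4.0       # 2.0 * 0.5 / 0.25
assert output1_scales_gate_scalar.item() == 1.0  # 2.0 * 0.5
assert output2_scales_scalar.item() == 1.0       # 0.25 * 4.0
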
@@ -193,8 +210,7 @@ def select_cutlass_fp8_gemm_impl(
             tp_size=moe.moe_parallel_config.tp_size,
         )
 
-    assert out_dtype is not None, (
-        "If moe config is None, out_dtype must be passed")
+    assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
     return FlashInferExperts(
         out_dtype=out_dtype,
         quant_config=quant_config,
@@ -217,9 +233,10 @@ def flashinfer_cutlass_moe_fp8(
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
-        select_cutlass_fp8_gemm_impl(moe=None,
-                                     quant_config=quant_config,
-                                     out_dtype=hidden_states.dtype))
+        select_cutlass_fp8_gemm_impl(
+            moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
+        ),
+    )
 
     return fused_experts(
         hidden_states,
@@ -245,4 +262,5 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
     allowed_backends = ["throughput", "latency"]
     raise ValueError(
         f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
-        f" expected one of {allowed_backends}")
+        f" expected one of {allowed_backends}"
+    )