[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -44,6 +44,20 @@ ROCM_AITER_UNIFIED_ATTN = pytest.param(
|
||||
),
|
||||
)
|
||||
|
||||
FLASHINFER_MLA_ATTN = pytest.param(
|
||||
AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA),
|
||||
id="FLASHINFER_MLA",
|
||||
marks=pytest.mark.skipif(
|
||||
not is_blackwell() or not has_flashinfer(),
|
||||
reason="FI backend requires Blackwell and FlashInfer",
|
||||
),
|
||||
)
|
||||
|
||||
TRITON_MLA_ATTN = pytest.param(
|
||||
AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA),
|
||||
id="TRITON_MLA",
|
||||
)
|
||||
|
||||
# Models
|
||||
llama3_8b = ModelFusionInfo(
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
@@ -126,3 +140,25 @@ qwen3_a3b_fp8 = ModelFusionInfo(
|
||||
async_tp=n_layers * 2,
|
||||
),
|
||||
)
|
||||
|
||||
deepseek_v3_fp8 = ModelFusionInfo(
|
||||
model_name="deepseek-ai/DeepSeek-V3",
|
||||
matches=lambda n_layers: Matches(
|
||||
# 3 per dense layer (first 3):
|
||||
# - input_rms + qkv_proj
|
||||
# - q_a_layernorm + q_b_proj (inside MLA wrapper)
|
||||
# - post_attn_layernorm + MLP
|
||||
# 2 per MoE layer (remaining) due to MoE wrapping
|
||||
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
|
||||
# TODO silu+block quant
|
||||
# act_quant_fusion=min(3, n_layers), # dense layers only
|
||||
act_quant_fusion=0,
|
||||
# MLA attn + quant not supported yet:
|
||||
# https://github.com/vllm-project/vllm/issues/35792
|
||||
attn_quant_fusion=0,
|
||||
ar_rms_fusion=n_layers * 2 + 1,
|
||||
# TODO
|
||||
# sequence_parallel= n_layers * 2 + 1,
|
||||
# async_tp=n_layers * 2,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user