[Model Bash][DSR1] Add selective dynamic shape marking for CustomOp (#34900)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Author: Vadim Gimpelson
Date: 2026-02-22 04:28:01 +04:00
Committed by: GitHub
Parent: a4047d4ea9
Commit: 74d90b1ce4
2 changed files with 46 additions and 7 deletions


@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import functools
+import inspect
+
 import torch
 import torch.nn as nn
@@ -205,9 +208,9 @@ class CustomOp(nn.Module):
         NOTE: this does not enable fusion across ops, so opaque custom ops
         should still be unwrapped wherever possible.
         """
-        # Do not compile if compilation disabled
         from vllm.config.compilation import CompilationMode
+        # Do not compile if compilation disabled
         if not enable:
             return fn
@@ -220,14 +223,42 @@ class CustomOp(nn.Module):
         if compilation_config.backend == "eager":
             return fn
 
+        compile_options = maybe_disable_graph_partition(
+            current_platform.simple_compile_backend
+        )
+        backend = current_platform.simple_compile_backend
+
+        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
+        if dynamic_arg_dims is not None:
+            compiled_fn = torch.compile(
+                fn,
+                dynamic=False,
+                backend=backend,
+                options=compile_options,
+            )
+            sig = inspect.signature(fn)
+
+            @functools.wraps(fn)
+            def wrapper(*args, **kwargs):
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                for name, dims in dynamic_arg_dims.items():
+                    arg = bound.arguments.get(name)
+                    if arg is not None and isinstance(arg, torch.Tensor):
+                        dims_list = [dims] if isinstance(dims, int) else dims
+                        for d in dims_list:
+                            real_d = arg.ndim + d if d < 0 else d
+                            torch._dynamo.mark_dynamic(arg, real_d)
+                return compiled_fn(*args, **kwargs)
+
+            return wrapper
+
         # dynamic=True to avoid recompilations
         return torch.compile(
             fn,
             dynamic=True,
-            backend=current_platform.simple_compile_backend,
-            options=maybe_disable_graph_partition(
-                current_platform.simple_compile_backend
-            ),
+            backend=backend,
+            options=compile_options,
         )
 
     @classmethod
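For readers skimming the hunk above: `torch._dynamo.mark_dynamic(t, d)` tags dimension `d` of the concrete tensor `t` as symbolic, which is what lets the wrapper compile with `dynamic=False` for maximal specialization while keeping just the declared dims dynamic. A minimal standalone sketch of the effect, independent of this commit (the function and shapes are invented for illustration):

import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

compiled = torch.compile(double)

for batch in (2, 8, 32):
    x = torch.randn(batch, 16)
    # Invented example: dim 0 of this input becomes a symbolic size.
    torch._dynamo.mark_dynamic(x, 0)
    compiled(x)  # compiles once; later batch sizes reuse the same graph

Note that marking applies to tensor objects, not to the compiled function, which is why the wrapper above re-marks the declared arguments on every call.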
@@ -267,10 +298,15 @@ class CustomOp(nn.Module):
     # Decorator to register custom ops.
     @classmethod
-    def register(cls, name: str):
+    def register(
+        cls,
+        name: str,
+        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    ):
         def decorator(op_cls):
             assert name not in op_registry, f"Duplicate op name: {name}"
             op_cls.name = name
+            op_cls._dynamic_arg_dims = dynamic_arg_dims
             op_registry[name] = op_cls
             return op_cls
 
         return decorator
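Taken together, the two changes give op authors a per-argument opt-in to selective dynamism. A hypothetical registration (the op name, class, and `forward_native` body are invented; a real subclass would implement the rest of the CustomOp dispatch surface):

@CustomOp.register("toy_scale", dynamic_arg_dims={"x": 0})
class ToyScale(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical op: dim 0 (e.g. the token count) is declared dynamic
        # in the decorator, so the compiled graph is reused as it varies,
        # while dim 1 and beyond stay fully specialized.
        return x * 2.0

Leaving `dynamic_arg_dims` at its `None` default preserves the old behavior: the op is compiled with `dynamic=True` and no explicit marking.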


@@ -950,7 +950,10 @@ def dynamic_per_batched_tensor_quant(
 logger = init_logger(__name__)
 
-@CustomOp.register("mla_decode_concat_quant_fp8")
+@CustomOp.register(
+    "mla_decode_concat_quant_fp8",
+    dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0},
+)
 class _DecodeConcatQuantFP8(QuantFP8):
     """
     QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before
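For this op, dim 0 of `decode_ql_nope` and `decode_q_pe` is the number of decode tokens, which varies step to step, while the head and feature dimensions are fixed, so it is the natural dimension to mark. A standalone way to confirm that a single graph serves many sizes of a marked dim (the function below is a stand-in for the real op, and `counters` is a Dynamo internal used only as a quick experiment):

import torch
from torch._dynamo.utils import counters  # internal API; experiment only

def concat_then_scale(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for concat + quant: concatenate along the last dim, then scale.
    return torch.cat([a, b], dim=-1) * 0.5

compiled = torch.compile(concat_then_scale)
for n in (2, 7, 64):  # stand-in for varying decode-token counts
    a, b = torch.randn(n, 16), torch.randn(n, 8)
    torch._dynamo.mark_dynamic(a, 0)
    torch._dynamo.mark_dynamic(b, 0)
    compiled(a, b)

print(counters["stats"]["unique_graphs"])  # expect 1: no per-shape recompiles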