[Model Bash][DSR1] Add selective dynamic shape marking for CustomOp (#34900)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import inspect
+
 import torch
 import torch.nn as nn
 
@@ -205,9 +208,9 @@ class CustomOp(nn.Module):
         NOTE: this does not enable fusion across ops, so opaque custom ops
         should still be unwrapped wherever possible.
         """
-        # Do not compile if compilation disabled
         from vllm.config.compilation import CompilationMode
 
+        # Do not compile if compilation disabled
         if not enable:
             return fn
 
@@ -220,14 +223,42 @@ class CustomOp(nn.Module):
         if compilation_config.backend == "eager":
             return fn
 
+        compile_options = maybe_disable_graph_partition(
+            current_platform.simple_compile_backend
+        )
+        backend = current_platform.simple_compile_backend
+
+        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
+        if dynamic_arg_dims is not None:
+            compiled_fn = torch.compile(
+                fn,
+                dynamic=False,
+                backend=backend,
+                options=compile_options,
+            )
+            sig = inspect.signature(fn)
+
+            @functools.wraps(fn)
+            def wrapper(*args, **kwargs):
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                for name, dims in dynamic_arg_dims.items():
+                    arg = bound.arguments.get(name)
+                    if arg is not None and isinstance(arg, torch.Tensor):
+                        dims_list = [dims] if isinstance(dims, int) else dims
+                        for d in dims_list:
+                            real_d = arg.ndim + d if d < 0 else d
+                            torch._dynamo.mark_dynamic(arg, real_d)
+                return compiled_fn(*args, **kwargs)
+
+            return wrapper
+
         # dynamic=True to avoid recompilations
         return torch.compile(
             fn,
             dynamic=True,
-            backend=current_platform.simple_compile_backend,
-            options=maybe_disable_graph_partition(
-                current_platform.simple_compile_backend
-            ),
+            backend=backend,
+            options=compile_options,
         )
 
     @classmethod
@@ -267,10 +298,15 @@ class CustomOp(nn.Module):
 
     # Decorator to register custom ops.
     @classmethod
-    def register(cls, name: str):
+    def register(
+        cls,
+        name: str,
+        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    ):
         def decorator(op_cls):
             assert name not in op_registry, f"Duplicate op name: {name}"
             op_cls.name = name
+            op_cls._dynamic_arg_dims = dynamic_arg_dims
             op_registry[name] = op_cls
             return op_cls
 
@@ -950,7 +950,10 @@ def dynamic_per_batched_tensor_quant(
 logger = init_logger(__name__)
 
 
-@CustomOp.register("mla_decode_concat_quant_fp8")
+@CustomOp.register(
+    "mla_decode_concat_quant_fp8",
+    dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0},
+)
 class _DecodeConcatQuantFP8(QuantFP8):
     """
     QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before
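The mechanism in this commit: CustomOp.register now accepts an optional
dynamic_arg_dims mapping from argument names to the tensor dimensions that
vary across calls (an int or a list of ints, with negative values counted
from the last dimension). When a registered op carries this map, the compile
path binds the call arguments, marks only the listed dimensions dynamic via
torch._dynamo.mark_dynamic, and dispatches to a dynamic=False torch.compile
artifact, so a varying batch size reuses one compiled graph while every
other shape stays specialized. The _DecodeConcatQuantFP8 hunk shows the
intended usage: dim 0 of decode_ql_nope and decode_q_pe is marked dynamic.

A minimal standalone sketch of the same pattern, outside vLLM (f and call_f
are illustrative names, not part of the codebase):

    import torch

    def f(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale

    # Compile with dynamic=False so all shapes are specialized by default.
    compiled_f = torch.compile(f, dynamic=False)

    def call_f(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Mark only dim 0 (the batch dimension) as dynamic, mirroring
        # dynamic_arg_dims={"x": 0}; the remaining dims stay specialized.
        torch._dynamo.mark_dynamic(x, 0)
        return compiled_f(x, scale)

    # Different batch sizes reuse one compiled graph; changing the
    # non-marked last dim (16 -> 32) would still trigger a recompile.
    call_f(torch.randn(8, 16), 2.0)
    call_f(torch.randn(64, 16), 2.0)

Compared with the dynamic=True fallback kept at the end of the compile
path, selective marking avoids treating every dimension as dynamic while
still preventing per-batch-size recompilation for the marked arguments.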