Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Commit: d6953beb91 (parent: 17edd8a807)
1508 changed files with 115244 additions and 94146 deletions
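
Most of the diff below is mechanical reformatting: yapf's aligned hanging indents become black-compatible wrapping, single quotes become double quotes, backslash continuations become parenthesized expressions, and long argument lists are either collapsed onto one line or exploded one-per-line with a trailing comma. A small self-contained sketch of the before/after shape (the function here is hypothetical, chosen only to mirror the call sites changed below):

def configure(temperature, max_tokens, top_p):
    return {"temperature": temperature, "max_tokens": max_tokens, "top_p": top_p}

# yapf style (before): arguments aligned under the opening parenthesis
params = configure(temperature=0.0,
                   max_tokens=10,
                   top_p=1.0)

# ruff format style (after): black-compatible wrapping inside the parentheses
params = configure(
    temperature=0.0, max_tokens=10, top_p=1.0
)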


@@ -23,8 +23,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
@@ -45,20 +44,18 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.inductor_config)
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
@@ -82,8 +79,7 @@ class TestBackend:
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
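
For orientation, TestBackend is used later in this diff as a torch.compile backend that installs custom post-grad passes and keeps a copy of the graph from before compilation for op-level checks. A minimal usage sketch, assuming the test-suite import path and a default VllmConfig; the pass and model below are placeholders, not part of this commit:

import torch
from torch import fx, nn

from tests.compile.backend import TestBackend  # assumed path inside the vLLM repo
from vllm.config import VllmConfig, set_current_vllm_config


def noop_pass(graph: fx.Graph) -> None:
    # stand-in for a real custom pass such as AsyncTPPass(vllm_config)
    pass


with set_current_vllm_config(VllmConfig()):
    backend = TestBackend(noop_pass)           # wires the pass into inductor_compile_config
    model = nn.Linear(16, 16)
    compiled = torch.compile(model, backend=backend)
    compiled(torch.randn(8, 16))               # triggers compile_fx with the pass installed
    # backend.graph_pre_compile holds a deep copy of the graph taken before compilation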


@@ -38,8 +38,8 @@ test_params_full_cudagraph = []
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
)
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
@@ -47,7 +47,8 @@ other_backend_configs = [
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
)
@pytest.fixture(scope="class")
@@ -55,8 +56,10 @@ def llm_pair(request):
model, backend_config = request.param
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if (
backend_config.specific_gpu_arch
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
):
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
@@ -76,8 +79,7 @@ def llm_pair(request):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=128,
compilation_config=\
CompilationConfig(**backend_config.comp_config),
compilation_config=CompilationConfig(**backend_config.comp_config),
generation_config="vllm",
seed=42,
)
@@ -113,20 +115,22 @@ class TestFullCUDAGraph:
meaning there would be multiple LLM instances hogging memory simultaneously.
"""
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
])
def test_full_cudagraph(self, batch_size, max_tokens,
llm_pair: tuple[LLM, LLM]):
@pytest.mark.parametrize(
("batch_size", "max_tokens"),
[
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
],
)
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
@@ -137,26 +141,34 @@ class TestFullCUDAGraph:
prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=1.0)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=max_tokens, top_p=1.0
)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
assert (
piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
# Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
with (
temporary_environ(
{
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
# Flex_Attention is not supported with full cuda graph
}
),
pytest.raises(RuntimeError),
):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
)
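
One structural change ruff format applies throughout this file is visible above: chained context managers previously joined with backslash continuations become a single parenthesized with statement. That form is standard Python (documented since 3.10); a minimal standalone sketch:

from contextlib import nullcontext

with (
    nullcontext("env") as first,
    nullcontext("raises") as second,
):
    assert (first, second) == ("env", "raises")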


@@ -10,10 +10,14 @@ from torch import nn
from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention`
@@ -27,12 +31,7 @@ RANDOM_SEED = 0
@support_torch_compile
class ParentModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -40,7 +39,6 @@ class ParentModel(nn.Module):
class Attention(nn.Module):
def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
@@ -51,17 +49,21 @@ class Attention(nn.Module):
nn.init.xavier_normal_(
self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
gain=0.001,
)
nn.init.xavier_normal_(
self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
gain=0.001,
)
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
x_f32 = x.float()
return (x_f32 * torch.rsqrt(
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
self.rms_norm_weight).to(x.dtype)
return (
x_f32
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
* self.rms_norm_weight
).to(x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x)
@@ -76,14 +78,15 @@ class Attention(nn.Module):
@support_torch_compile
class CompiledAttention(nn.Module):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.attn = Attention(mlp_size, hidden_size)
@@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
@support_torch_compile
class CompiledAttentionTwo(CompiledAttention):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x
@ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)"
@@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode):
def run_model(
vllm_config: VllmConfig,
model: nn.Module,
inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode,
):
with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE
model(inputs)
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(inputs[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(inputs[:1])
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(inputs[:2])
output = output.cpu()
@@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal():
outputs = []
# piecewise compile
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION,
)
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# piecewise compile without CUDA graph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly.attention"],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly.attention"],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=0, # no cudagraph captured
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=0, # no cudagraph captured
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
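
The counter expectations in the piecewise-compile test above follow directly from the model structure described in the diff's comments; a quick arithmetic check (values taken from those comments, with [1, 2] as the configured capture sizes):

num_compiled_modules = 2       # attn_one and attn_two are compiled separately
pieces_per_module = 3          # pre-attn, silly_attention, post-attn
capturable_per_module = 2      # pre-attn and post-attn (the attention op is split out)
cudagraph_capture_sizes = [1, 2]

assert num_compiled_modules * pieces_per_module == 6      # num_piecewise_graphs_seen
assert num_compiled_modules * capturable_per_module == 4  # num_piecewise_capturable_graphs_seen
assert len(cudagraph_capture_sizes) * 4 == 8              # num_cudagraph_captured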


@@ -11,8 +11,13 @@ from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.config import (
CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
@@ -23,12 +28,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -60,53 +60,65 @@ def _run_simple_model(
expected_num_backend_compilations,
expected_num_cudagraph_captured,
):
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
)
)
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).cuda()
with compilation_counter.expect(
with (
compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=
expected_num_piecewise_capturable_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None,
vllm_config=vllm_config): # background context
),
set_forward_context(None, vllm_config=vllm_config),
): # background context
# warm up with background context
model(inputs)
# capturing/replaying should under context of cudagraph dispatching
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
reset_global_counter()
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input)
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
@@ -122,10 +134,8 @@ def test_simple_piecewise_compile(use_inductor):
use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations=
3, # num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
)
@@ -134,8 +144,7 @@ def test_simple_piecewise_compile(use_inductor):
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
_run_simple_model(
# inductor graph partition automatically resets splitting_ops
@@ -143,13 +152,9 @@ def test_simple_inductor_graph_partition(splitting_ops):
splitting_ops=splitting_ops,
use_inductor_graph_partition=True,
use_inductor=True,
expected_num_piecewise_graphs_seen=
1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=
1, # since not splitting at fx graph level
expected_num_backend_compilations=
1, # since not splitting at fx graph level
expected_num_cudagraph_captured=
6, # inductor graph partition still captures 6
expected_num_piecewise_graphs_seen=1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=1, # since not splitting at fx graph level
expected_num_backend_compilations=1, # since not splitting at fx graph level
expected_num_cudagraph_captured=6, # inductor graph partition still captures 6
# graph, same as fx graph partition.
)
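
The same bookkeeping applies to _run_simple_model above: with two layers and capture sizes [1, 2], the expected counters reduce to a couple of products, and inductor graph partitioning changes only the fx-level counts:

num_layers = 2
cudagraph_capture_sizes = [1, 2]

assert 2 * num_layers + 1 == 5                # piecewise graphs with fx-level splitting
assert 1 + num_layers == 3                    # capturable piecewise graphs
assert len(cudagraph_capture_sizes) * 3 == 6  # cudagraphs captured
# with use_inductor_graph_partition=True the fx graph is left whole (those counts drop to 1),
# yet 6 cudagraphs are still captured, matching the fx-split case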


@@ -8,6 +8,7 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
from typing import Any, Optional
@@ -17,8 +18,13 @@ from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.config import (
CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention`
@@ -43,15 +49,14 @@ class LlamaConfig:
factors.append((k, v))
factors.sort()
import hashlib
return hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
def __post_init__(self):
assert self.mlp_size >= self.hidden_size
class LlamaMLP(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.gate_up_projection = nn.Linear(
@@ -66,31 +71,31 @@ class LlamaMLP(nn.Module):
)
if config.tractable_init:
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
nn.init.eye_(self.down_projection.weight.data)
else:
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.down_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(
self.gate_up_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
nn.init.xavier_normal_(
self.down_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(self, x):
# for tractable_init and positive input, this is
# essentially an elementwise-square
x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
x[:, x.size(1) // 2:])
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
x = self.down_projection(x)
return x
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.qkv_projection = nn.Linear(
@@ -106,21 +111,25 @@ class LlamaAttention(nn.Module):
)
if config.tractable_init:
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[2 *
config.hidden_size:])
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
nn.init.eye_(
self.qkv_projection.weight.data[
config.hidden_size : 2 * config.hidden_size
]
)
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
nn.init.eye_(self.output_projection.weight.data)
else:
nn.init.xavier_normal_(self.qkv_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.output_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(
self.qkv_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
nn.init.xavier_normal_(
self.output_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(
self,
@@ -144,7 +153,6 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.self_attention = LlamaAttention(config)
@@ -164,7 +172,7 @@ class LlamaDecoderLayer(nn.Module):
- if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
""" # noqa
""" # noqa
if residual is None:
residual = hidden_states
hidden_states = hidden_states + 1
@@ -173,8 +181,9 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = hidden_states + 1
hidden_states = self.self_attention(positions=positions,
hidden_states=hidden_states)
hidden_states = self.self_attention(
positions=positions, hidden_states=hidden_states
)
hidden_states = hidden_states + residual
residual = hidden_states
@@ -186,20 +195,22 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
)
# this is the initial value of the hidden states
self.embedding_tokens.weight.data.fill_(config.init_value)
@@ -216,34 +227,39 @@ class LlamaModel(nn.Module):
return hidden_states
def tractable_computation(input_ids: torch.Tensor,
positions: torch.Tensor,
config: LlamaConfig,
init_value: float = 1.0) -> torch.Tensor:
hidden_states = torch.ones(input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype) * init_value
def tractable_computation(
input_ids: torch.Tensor,
positions: torch.Tensor,
config: LlamaConfig,
init_value: float = 1.0,
) -> torch.Tensor:
hidden_states = (
torch.ones(
input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype,
)
* init_value
)
# first layer
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
hidden_states = (residual + 1) ** 2
# following layers
for _ in range(config.num_layers - 1):
hidden_states = hidden_states + residual
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
hidden_states = (residual + 1) ** 2
return hidden_states
@torch.inference_mode
def run_model(llama_config,
use_compile: bool,
use_inductor: bool,
split_attn: bool = False) -> torch.Tensor:
def run_model(
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
@@ -256,54 +272,66 @@ def run_model(llama_config,
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
level=CompilationLevel.NO_COMPILATION,
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
vllm_config = VllmConfig(
compilation_config=compilation_config, additional_config=llama_config
)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
model = (
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
.eval()
.cuda()
)
with set_forward_context({},
vllm_config=vllm_config): # background context
with set_forward_context({}, vllm_config=vllm_config): # background context
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE
model(input_ids, positions)
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(input_ids[:2], positions[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input_ids[:2], positions[:2])
output = output.cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2],
positions[:2],
llama_config).cpu()
expected_output = tractable_computation(
input_ids[:2], positions[:2], llama_config
).cpu()
assert torch.allclose(output, expected_output)
else:
@@ -314,27 +342,23 @@ def run_model(llama_config,
def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=12)
llama_config = LlamaConfig(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
)
tractable_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=2,
tractable_init=True)
tractable_config = LlamaConfig(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
)
outputs = []
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(
run_model(llama_config, use_inductor=False, use_compile=False))
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
if use_inductor:
@@ -343,41 +367,41 @@ def test_toy_llama(use_inductor: bool):
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**kwargs,
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**kwargs,
):
outputs.append(
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True))
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers +
1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1 +
llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1 +
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2 *
(1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1
+ llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1
+ llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2
* (
1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True))
run_model(tractable_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True)
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
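
The "tractable" weights in this file make every decoder layer an explicit polynomial in its inputs, which is what lets tractable_computation predict the piecewise-compiled output exactly. A standalone numeric check of the per-layer identity quoted in the LlamaDecoderLayer docstring, (h + r + 1) * 3 + p * 2 + h + r == (h + r) * 4 + p * 2 + 3:

import torch

h, r, p = torch.rand(4), torch.rand(4), torch.rand(4)
lhs = (h + r + 1) * 3 + p * 2 + h + r
rhs = (h + r) * 4 + p * 2 + 3
assert torch.allclose(lhs, rhs)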
@@ -388,17 +412,15 @@ def benchmark():
from triton.testing import do_bench
# similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096,
mlp_size=14336,
vocab_size=128 * 1024,
num_layers=32)
llama_config = LlamaConfig(
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
)
# a tiny model to measure the overhead
# of piecewise cudagraph
llama_config = LlamaConfig(hidden_size=40,
mlp_size=80,
vocab_size=128,
num_layers=2)
llama_config = LlamaConfig(
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
)
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
@@ -424,12 +446,15 @@ def benchmark():
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)
model = (
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
.eval()
.cuda()
.to(torch.bfloat16)
)
B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda().to(torch.bfloat16)
graphs = {}
@@ -451,21 +476,25 @@ def benchmark():
# and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once.
runtime = do_bench(lambda: graphs[b][0] # noqa
(input_ids[:b], positions[:b])) # noqa
runtime = do_bench(
lambda: graphs[b][0]( # noqa
input_ids[:b], positions[:b]
)
) # noqa
piecewise_cudagraph_time[b] = runtime
else:
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
eager_runtime = do_bench(
lambda: model(input_ids[:b], positions[:b])) # noqa
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
full_cudagraph_time[b] = runtime
eager_time[b] = eager_runtime
# print in tabular format
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
for b in cudagraph_sizes:
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}")
print(
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}"
)
if __name__ == "__main__":


@@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""
Unified attention implementation that depends on
all inputs and affects the output.
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out.copy_(q + k + v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""Fake implementation for testing"""
return
@@ -60,5 +62,5 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
tags=(torch._C.Tag.cudagraph_unsafe,),
)
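
The registration above exposes the helper as torch.ops.silly.attention, with `out` listed in mutates_args so the op writes its result in place rather than returning it. A minimal call sketch, assuming this silly_attention module has been imported so the op is registered:

import torch

q, k, v = torch.ones(4), torch.full((4,), 2.0), torch.full((4,), 3.0)
out = torch.empty(4)
torch.ops.silly.attention(q, k, v, out)        # writes q + k + v into `out` in place
assert torch.allclose(out, torch.full((4,), 6.0))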


@@ -8,18 +8,30 @@ import torch
import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.config import (
CompilationConfig,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test)
from ..utils import (
compare_two_settings,
create_new_process_for_each_test,
multi_gpu_test,
)
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
@@ -33,21 +45,20 @@ prompts = [
class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
self.gate_proj = torch.nn.Parameter(
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states):
"""
Forward pass implementing the mm + reduce scatter in the FX graph
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
@@ -66,14 +77,13 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
self.weight = torch.nn.Parameter(
torch.empty((hidden_size, hidden_size)), requires_grad=False
)
# Initialize weights
torch.nn.init.normal_(self.weight, std=0.02)
@@ -96,32 +106,35 @@ class TestAGMMModel(torch.nn.Module):
class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
.contiguous().transpose(0, 1)
self.weight = (
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
# Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
scaled_mm = torch._scaled_mm(
fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype,
)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter
@@ -133,7 +146,6 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
@@ -143,11 +155,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
scaled_mm = torch._scaled_mm(
all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype,
)
return scaled_mm
def ops_in_model_before(self):
@@ -158,20 +172,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device)
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
self.scale_b, None)
mm_out = torch.empty(
(fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device,
)
torch.ops._C.cutlass_scaled_mm(
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter
@@ -183,10 +199,9 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + cutlass_scaled_mm
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
"""
# Reshape input
@@ -195,11 +210,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device)
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
scale_a, self.scale_b, None)
mm_out = torch.empty(
(all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device,
)
torch.ops._C.cutlass_scaled_mm(
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
)
return mm_out
def ops_in_model_before(self):
@@ -210,23 +228,37 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
])
@pytest.mark.parametrize(
"test_model",
[
TestMMRSModel,
TestAGMMModel,
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel) and dtype == torch.float16:
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
):
if (
test_model
in (
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
)
and dtype == torch.float16
):
pytest.skip(
"Only bf16 high precision output types are supported for " \
"Only bf16 high precision output types are supported for "
"per-token (row-wise) scaling"
)
@@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
nprocs=nprocs,
)
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
def async_tp_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
@@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment()
@@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_async_tp=True, ), )
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_async_tp=True,
),
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
trust_remote_code=True,
dtype=dtype,
seed=42)
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size,
dtype) # Pass dtype to model constructor
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
requires_grad=False)
hidden_states = torch.randn(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
@@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
@pytest.mark.parametrize(
"model_id",
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@@ -342,12 +382,10 @@ def test_async_tp_pass_correctness(
common_args.append("--enforce-eager")
compilation_config = {
'level': 3,
'compile_sizes': [2, 4, 8],
'splitting_ops': [],
'pass_config': {
'enable_async_tp': async_tp_enabled
},
"level": 3,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {"enable_async_tp": async_tp_enabled},
}
async_tp_env = tp_env = {
@@ -372,9 +410,6 @@ def test_async_tp_pass_correctness(
"mp",
]
compare_two_settings(model_id,
async_tp_args,
tp_args,
async_tp_env,
tp_env,
method="generate")
compare_two_settings(
model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate"
)


@@ -103,23 +103,28 @@ def test_compile_correctness(
attn_backend = test_setting.attn_backend
method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size:
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}")
pytest.skip(
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}"
)
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [
"--enforce-eager", *model_args, "-pp",
str(pp_size), "-tp",
str(tp_size)
"--enforce-eager",
*model_args,
"-pp",
str(pp_size),
"-tp",
str(tp_size),
]
all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = []
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
@@ -130,14 +135,15 @@ def test_compile_correctness(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close")
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]:
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})


@@ -9,11 +9,11 @@ from vllm.utils import _is_torch_equal_or_newer
def test_version():
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_use_cudagraphs_dynamic(monkeypatch):
@@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch):
vllm_config = VllmConfig()
assert vllm_config.compilation_config.use_cudagraph
monkeypatch.setenv('VLLM_USE_V1', '0')
monkeypatch.setenv("VLLM_USE_V1", "0")
vllm_config = VllmConfig()
assert not vllm_config.compilation_config.use_cudagraph
@@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = {
"use_cudagraph": False, # speed things up a bit
}
with (
compilation_counter.expect(num_cache_entries_updated=0,
num_compiled_artifacts_saved=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass
@@ -67,22 +71,25 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
}
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass
@@ -90,14 +97,17 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(dynamo_as_is_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 1},
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(dynamo_as_is_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"level": 1},
gpu_memory_utilization=0.4,
) as _,
):
pass
@@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 0},
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"level": 0},
gpu_memory_utilization=0.4,
) as _,
):
pass
@@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch):
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
enforce_eager=True,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
) as _,
):
pass
def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
if _is_torch_equal_or_newer('2.9.0.dev'):
if _is_torch_equal_or_newer("2.9.0.dev"):
# inductor graph partition is only available in PyTorch 2.9+.
# this is a fast config check so we are not using pytest.skip.
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
splitting_ops=["silly_attention"]))
config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
)
)
# should ignore splitting_ops
assert config.compilation_config.splitting_ops == []
# When attn_fusion pass enabled.
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
config = VllmConfig(
compilation_config=CompilationConfig(
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# cudagraph mode also falls back to FULL
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# splitting_ops cannot contain attention ops when the attn_fusion
# pass is enabled.
with pytest.raises(AssertionError):
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
))
config = VllmConfig(
compilation_config=CompilationConfig(
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# When both use_inductor_graph_partition and attn_fusion pass enabled.
if _is_torch_equal_or_newer('2.9.0.dev'):
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
if _is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
View File
@@ -4,10 +4,15 @@ import torch
from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
CUDAGraphMode, VllmConfig, set_current_vllm_config)
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationLevel,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
# This import automatically registers `torch.ops.silly.attention`
@@ -18,32 +23,42 @@ MLP_SIZE = 128
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
cudagraph_runtime_mode: CUDAGraphMode):
def run_model(
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
):
with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2, MLP_SIZE).cuda())
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1, MLP_SIZE).cuda())
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(torch.randn(2, MLP_SIZE).cuda())
output = output.cpu()
@@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module,
def test_ignore_torch_compile_decorator():
# piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@support_torch_compile
class A(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -79,66 +93,60 @@ def test_ignore_torch_compile_decorator():
return x
@ignore_torch_compile
class B(A):
...
class B(A): ...
@support_torch_compile
class C(B):
...
class C(B): ...
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
# B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
# C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class B(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -152,15 +160,11 @@ class B(nn.Module):
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
cache_config.kv_sharing_fast_prefill)
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class A(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
@@ -175,54 +179,60 @@ class A(nn.Module):
def test_conditional_compile_enable_if():
vllm_config = VllmConfig(cache_config=CacheConfig(
kv_sharing_fast_prefill=True, ),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=True,
),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile but enable_if fn returns False
# enable_if will be True for B, so we expect mod1 and mod2
# to be compiled
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
# Set kv_sharing_fast_prefill=False
# which will cause A to be compiled and B to not be compiled
vllm_config = VllmConfig(cache_config=CacheConfig(
kv_sharing_fast_prefill=False, ),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=False,
),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
),
)
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=7,
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1,
num_piecewise_graphs_seen=7,
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
View File
@@ -14,8 +14,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
@@ -25,43 +24,54 @@ from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
}),
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
"dtype": torch.float16,
}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{
"dtype": torch.float16,
},
),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{
"dtype": torch.float16,
},
),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]
if all:
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
)
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
)
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
TEST_MODELS.append(
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
{"quantization": "gptq_marlin"},
)
)
if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))
TEST_MODELS.append(
(
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
{"quantization": "gptq_marlin_24"},
)
)
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
)
if keywords is None:
return TEST_MODELS
@@ -95,22 +105,34 @@ def test_full_graph(
"compilation_config, model_info",
[
# additional compile sizes, only some of the models
(CompilationConfig(level=CompilationLevel.PIECEWISE,
compile_sizes=[1, 2]), model)
(
CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
model,
)
for model in models_list(all=False)
] + [
]
+ [
# RMSNorm + quant fusion, only 8-bit quant models
(CompilationConfig(level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True,
enable_noop=True)), model)
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
model,
)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
] + [
]
+ [
# Test depyf integration works
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
] + [
(
CompilationConfig(
level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
),
("facebook/opt-125m", {}),
),
]
+ [
# graph inductor partition
(
CompilationConfig(
@@ -119,20 +141,24 @@ def test_full_graph(
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]),
model) for model in models_list(all=False)
compile_sizes=[1, 2],
),
model,
)
for model in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
])
],
)
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
):
if (compilation_config.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}")
@@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int):
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
@@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with caplog_vllm.at_level(
logging.DEBUG), global_force_attn_backend_context_manager(
_Backend.FLASHINFER):
with (
caplog_vllm.at_level(logging.DEBUG),
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
):
run_model(compilation_config, model, model_kwargs)
try:
assert ("Fused quantization onto 48 attention nodes"
in caplog_vllm.text), caplog_vllm.text
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
caplog_vllm.text
)
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
@@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm):
assert "Fused quantization" not in caplog_vllm.text
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
def run_model(
compile_config: Union[int, CompilationConfig],
model: str,
model_kwargs: dict[str, Any],
):
prompts = [
"Hello, my name is",
"The president of the United States is",
View File
@@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
@@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size,
hidden_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
@@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y,
self.w,
self.wscale,
input_scale=self.wscale)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
else:
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
@@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
@@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype))
torch.empty((intermediate_size, hidden_size), dtype=dtype)
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype))
torch.ones(intermediate_size, dtype=dtype)
)
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
@@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
@@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module):
def __init__(self,
head_dim=64,
rotary_dim=None,
max_position=2048,
base=10000):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
@@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
def __init__(self,
head_dim=64,
num_heads=4,
max_position=2048,
base=10000):
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.num_heads = num_heads
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(self.hidden_size,
self.hidden_size * 3,
bias=False,
dtype=torch.float16)
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -233,21 +213,24 @@ MODELS = [
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion else [noop_pass, cleanup_pass])
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
@@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
is None) # noqa: E501
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
View File
@@ -5,17 +5,26 @@ import pytest
import torch
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
RMSNormQuantFusionPass)
from vllm.compilation.fusion import (
FUSED_OPS,
QUANT_OPS,
FusedRMSQuantKey,
RMSNormQuantFusionPass,
)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
@@ -25,9 +34,15 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool,
cuda_force_torch: bool, *args, **kwargs):
def __init__(
self,
hidden_size: int,
eps: float,
static: bool,
cuda_force_torch: bool,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
@@ -54,17 +69,15 @@ class TestModel(torch.nn.Module):
resid = torch.sqrt(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(y,
self.w[0],
self.wscale[0],
input_scale=self.scale[0])
x2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
# make sure resid is used so the replacement works
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2,
self.w[1],
self.wscale[1],
input_scale=self.scale[1])
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
@@ -74,7 +87,7 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
]
@@ -85,22 +98,27 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms where
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
View File
@@ -10,14 +10,24 @@ from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.config import (
CompilationConfig,
CompilationLevel,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8)
GroupShape,
QuantFP8,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@@ -26,7 +36,6 @@ from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
@@ -47,7 +56,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
@@ -68,25 +76,22 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
torch.ops._C.static_scaled_fp8_quant(
self.output, norm_output.contiguous(), self.scale
)
return self.output, residual_output
def ops_in_model_after(self):
@@ -95,35 +100,33 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
torch.ops._C.static_scaled_fp8_quant.default,
]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
dtype=torch.int32)
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
self.output_scale, self.scale)
torch.ops._C.scaled_fp4_quant(
self.output, norm_output, self.output_scale, self.scale
)
return self.output, residual_output, self.output_scale
def ops_in_model_after(self):
@@ -132,7 +135,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
torch.ops._C.scaled_fp4_quant.default,
]
@@ -145,41 +148,55 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
"is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
test_model: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")
if (
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)
):
pytest.skip(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
nprocs=nprocs,
)
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
def all_reduce_fusion_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
@@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
)
)
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True)
enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
trust_remote_code=True,
dtype=dtype,
seed=42)
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
View File
@@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationLevel,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
)
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
)
def test_attention_fusion_v0(
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
@@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048)
llm = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=10,
top_p=0.95)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
@@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048)
llm2 = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
# check support
attn_fusion_supported = [
@@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \
f"Node {i} {'' if fused else 'not '} expected " \
f"to have fused output quant"
assert fused == attn_fusion_supported[i], (
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
)
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
@@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig, **kwargs):
def __init__(
self,
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
**kwargs,
):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
@@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
query_lens=[1] * batch_size)
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
self.block_size,
self.device,
arange_block_indices=True)
batch_spec, self.block_size, self.device, arange_block_indices=True
)
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
1) // self.block_size
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
if current_platform.is_rocm():
# k/v as 1st dimension
if use_hnd:
@@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
# Build attn metadata
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return self.attn_metadata
@@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape)
act_quant_group_shape=self.quant_key.scale.group_shape,
)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randn(hidden_size, hidden_size).to(
dtype=FP8_DTYPE, device=self.device).t(),
"wscale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
})
"w",
{
"weight": torch.randn(hidden_size, hidden_size)
.to(dtype=FP8_DTYPE, device=self.device)
.t(),
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
},
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"])
return self.fp8_linear.apply(
input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
@@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device),
"wscale_swizzled":
torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device),
"wscale":
torch.tensor([500], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([0.002], dtype=torch.float32, device=self.device),
})
"w",
{
"weight": torch.randint(
256,
(hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device,
),
"wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device
),
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
},
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"])
return cutlass_scaled_fp4_mm(a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype)
attn_output, 1 / self.w["scale"]
)
return cutlass_scaled_fp4_mm(
a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype,
)
if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)]
MODELS = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel,
),
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel,
),
]
HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
TestAttentionFp8StaticQuantPatternModel)]
MODELS = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
@@ -331,41 +368,53 @@ else:
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size",
[7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize(
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend",
[_Backend.FLASHINFER] if current_platform.is_cuda()
else [_Backend.TRITON_ATTN])
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
"backend",
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
"split_attention", [False, True] if current_platform.is_rocm() else [False]
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
[False] if current_platform.is_rocm() else [False, True],
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(current_platform.is_cuda()
and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch, dist_init, caplog_vllm):
@pytest.mark.skipif(
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
batch_size: int,
dtype: torch.dtype,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch,
dist_init,
caplog_vllm,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
@@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),
cache_config=CacheConfig(cache_dtype="fp8"))
cache_config=CacheConfig(cache_dtype="fp8"),
)
# Create test inputs
q = torch.randn(batch_size,
num_qo_heads * head_size,
dtype=dtype,
device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
v = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
@@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# Run model directly without compilation and fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused)
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
global_force_attn_backend_context_manager(backend),
):
model_unfused = model_class(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused,
)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention)
batch_size, use_hnd=split_attention
)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True)
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w)
enable_attn_fusion=True, enable_noop=True
)
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
global_force_attn_backend_context_manager(backend),
):
model_fused = model_class(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)
batch_size, use_hnd=split_attention
)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
@@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
backend=test_backend,
fullgraph=True)
model_compiled = torch.compile(
model_fused, backend=test_backend, fullgraph=True
)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)
@@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert model_compiled.attn._o_scale_float is not None
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
)
# Check attn fusion support
quant_key = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
vllm_config.compilation_config.static_forward_context.items()
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
test_backend.graph_post_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), \
assert len(attn_nodes_pre) == len(attn_nodes_post), (
"Should have same number of attention nodes before and after fusion"
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
)
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
"Attention should not have output_scale before fusion"
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
)
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
"Attention should have output_scale after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale before fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale after FP8 fusion"
)
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
"Attention should have output_block_scale after FP4 fusion"
) # noqa: E501
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)
View File
@@ -6,14 +6,12 @@ import torch
import vllm
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
from .backend import TestBackend
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
@@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
torch.manual_seed(1)
class Model(torch.nn.Module):
def forward(self, x):
# Chain of reshapes
y = x.reshape(-1, 128, 32)
@@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
# Final reshape that should remain
b = a.reshape(-1, 128, 32)
# No-op slice
c = b[0:b.shape[0]]
c = b[0 : b.shape[0]]
# The pass should replace the result of this op with `c`.
d = torch.slice_scatter(
torch.ones_like(c), # Dummy tensor to be scattered into
@@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
)
return d
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
pass_config=PassConfig(enable_noop=True),
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
pass_config=PassConfig(enable_noop=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
@@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
x = torch.randn(16, 16)
class SliceModel(torch.nn.Module):
def forward(self, x):
base = x.clone()
src = torch.ones(15, 16)
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
return x[0:-1, :], y
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
pass_config=PassConfig(enable_noop=True),
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
pass_config=PassConfig(enable_noop=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
backend = TestBackend(noop_pass)
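As background for the two tests above: NoOpEliminationPass may only drop reshape/slice ops that provably return their input unchanged; anything that actually changes shape or content (like the `[0:-1]` slice in the second test) has to survive. A standalone sketch of that distinction, using plain eager torch and hypothetical shapes rather than FX:

    import torch

    x = torch.randn(256, 4096)

    # no-ops: a reshape chain that lands back on the original shape, and a
    # slice spanning the full extent of a dimension; values are untouched
    y = x.reshape(-1, 128, 32).reshape(256, 4096)
    c = y[0:y.shape[0]]
    assert torch.equal(c, x)

    # not a no-op: a slice that drops the last row must survive the pass
    assert x[0:-1, :].shape == (255, 4096)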


@@ -28,7 +28,6 @@ def test_bad_callable():
# Pass that inherits from InductorPass
class ProperPass(InductorPass):
def __call__(self, graph: torch.fx.graph.Graph) -> None:
pass
@@ -39,8 +38,7 @@ class ProperPass(InductorPass):
ProperPass(),
# Can also wrap callables in CallableInductorPass for compliance
CallableInductorPass(simple_callable),
CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)),
],
)
def test_pass_manager_uuid(callable):
@@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable):
# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.compilation_config.pass_config.enable_fusion = not \
config2.compilation_config.pass_config.enable_fusion
config2.compilation_config.pass_config.enable_fusion = (
not config2.compilation_config.pass_config.enable_fusion
)
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
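The UUID assertions matter because Inductor keys its compile cache on the pass manager's identity: if the UUID did not change when a pass's source or the relevant pass config changes, stale compiled artifacts could be reused. A rough sketch of the idea (the hashing logic here is illustrative, not vLLM's actual implementation):

    import hashlib
    import inspect
    import json

    def source_hash(obj) -> str:
        # hash the pass's source text so edits to the pass invalidate caches
        src = open(obj).read() if isinstance(obj, str) else inspect.getsource(obj)
        return hashlib.sha256(src.encode()).hexdigest()

    def manager_uuid(pass_hashes: list[str], pass_config: dict) -> str:
        # fold the per-pass hashes together with the config that controls them
        payload = json.dumps(
            {"passes": pass_hashes, "config": pass_config}, sort_keys=True
        )
        return hashlib.sha256(payload.encode()).hexdigest()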


@@ -12,14 +12,20 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.config import (
CompilationConfig,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@@ -36,16 +42,15 @@ prompts = [
class TestModel(torch.nn.Module):
def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)))
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -53,18 +58,18 @@ class TestModel(torch.nn.Module):
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
@@ -82,7 +87,7 @@ class TestModel(torch.nn.Module):
def ops_in_model_after(self):
return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default
torch.ops.vllm.all_gather.default,
]
def ops_in_model(self):
@@ -90,18 +95,16 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module):
def __init__(self,
hidden_size=16,
intermediate_size=32,
vllm_config: VllmConfig = None):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.gate_proj = torch.nn.Parameter(torch.empty(
(intermediate_size, hidden_size)),
requires_grad=False)
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -111,25 +114,24 @@ class TestQuantModel(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout.
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
#matrix multiplication
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
@@ -140,45 +142,51 @@ class TestQuantModel(torch.nn.Module):
norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(
norm_output.device))
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output
def ops_in_model_before(self):
ops_to_remove = [torch.ops.vllm.all_reduce.default
] # Always removed by SP
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
# The following are only removed if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_remove.extend([
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
])
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_remove.extend(
[
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.static_scaled_fp8_quant.default,
]
)
return ops_to_remove
def ops_in_model_after(self):
ops_to_add = [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default
torch.ops.vllm.all_gather.default,
]
# The following is only added if fusion happens
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
ops_to_add.append(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add
def ops_in_model(self):
if self.vllm_config and self.vllm_config.compilation_config \
.pass_config.enable_fusion:
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
] # noqa: E501
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] # noqa: E501
else:
# If no fusion, the original ops are checked
return [
@@ -195,30 +203,47 @@ class TestQuantModel(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype,
enable_fusion: bool):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module],
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
):
num_processes = 2
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model_cls,
batch_size, seq_len, hidden_size,
dtype, enable_fusion),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
test_model_cls,
batch_size,
seq_len,
hidden_size,
dtype,
enable_fusion,
),
nprocs=nprocs,
)
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
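The op lists in the models above encode the core rewrite this pass performs: the tensor-parallel all_reduce is replaced by a reduce_scatter (each rank keeps only its shard of the token dimension) followed later by an all_gather that restores the full activation. Conceptually, all_gather(reduce_scatter(x)) equals all_reduce(x). A single-process sketch of that identity, with hypothetical shapes and no real distributed collectives:

    import torch

    world_size = 2
    # per-rank partial results that an all_reduce would normally sum
    partials = [torch.randn(8, 16) for _ in range(world_size)]

    all_reduced = sum(partials)                    # what all_reduce yields on every rank
    shards = list(all_reduced.chunk(world_size, dim=0))   # reduce_scatter output per rank
    regathered = torch.cat(shards, dim=0)          # all_gather stitches the shards back

    assert torch.equal(regathered, all_reduced)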
def sequence_parallelism_pass_on_test_model(
local_rank: int, world_size: int,
test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
local_rank: int,
world_size: int,
test_model_cls: type[torch.nn.Module],
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_fusion: bool,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
@@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model(
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment()
@@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model(
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True)) # NoOp needed for fusion
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
trust_remote_code=True,
dtype=dtype,
seed=42)
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
@@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model(
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size,
hidden_size * 2,
vllm_config=vllm_config)
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
@@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model(
# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
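For context on the functionalization checks: `find_auto_fn` / `find_auto_fn_maybe` look for the higher-order `auto_functionalized` wrapper that PyTorch inserts around mutating custom ops, and FixFunctionalizationPass is expected to strip it so the raw op is called directly. A small illustration of what such a check can look like (an assumption about the helpers' shape, not their actual code):

    import torch
    from torch import fx

    def auto_fn_node_for(nodes, op) -> fx.Node | None:
        # a functionalized graph wraps the mutating op: auto_functionalized(op, ...)
        for node in nodes:
            if (node.op == "call_function"
                    and node.target is torch.ops.higher_order.auto_functionalized
                    and node.args and node.args[0] is op):
                return node
        return None

    # after FixFunctionalizationPass the wrapper should be gone:
    # assert auto_fn_node_for(backend_func.graph_post_pass.nodes, op) is None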


@@ -8,10 +8,15 @@ import torch
import vllm.envs as envs
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
FUSED_OPS,
SILU_MUL_OP,
ActivationQuantFusionPass,
)
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
@@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported)
Fp8LinearOp,
cutlass_fp8_supported,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
@@ -36,7 +46,6 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
@@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y,
self.w,
self.wscale,
input_scale=self.wscale)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
def ops_in_model_before(self):
@@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
silu_and_mul_nvfp4_quant_supported,
)
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul()
@@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
out = cutlass_scaled_fp4_mm(a=y_quant,
b=self.w,
block_scale_a=y_block_scale,
block_scale_b=self.w_block_scale,
alpha=self.alpha,
out_dtype=y.dtype)
out = cutlass_scaled_fp4_mm(
a=y_quant,
b=self.w,
block_scale_a=y_block_scale,
block_scale_b=self.w_block_scale,
alpha=self.alpha,
out_dtype=y.dtype,
)
return out
def ops_in_model_before(self):
@@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"model_class",
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
cast(
list[type],
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported()
else [TestSiluMulFp8QuantModel],
),
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
cuda_force_torch):
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
@@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
)
fusion_pass = ActivationQuantFusionPass(config)
passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch,
x=x)
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
@@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(result[0].to(dtype=dtype),
result2[0].to(dtype=dtype),
atol=atol,
rtol=rtol)
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
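For context, the pattern ActivationQuantFusionPass matches in the FP8 model above is the SiluAndMul activation followed immediately by a static FP8 quantization, which the fused kernel computes in one pass. A reference sketch of the unfused semantics, written from my reading of the ops involved (the e4m3 range is spelled out explicitly; this is not the fused kernel itself):

    import torch

    FP8_E4M3_MAX = 448.0  # finite max of torch.float8_e4m3fn

    def silu_mul_then_static_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        # SiluAndMul: SiLU on the first half of the last dim, times the second half
        y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
        # static per-tensor quantization: divide by a fixed scale, clamp to FP8 range
        q = torch.clamp(y / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
        return q.to(torch.float8_e4m3fn)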


@@ -10,7 +10,6 @@ from vllm.config import CompilationLevel
class MyMod(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
if cache is not None:
return x + cache
@@ -18,12 +17,12 @@ class MyMod(torch.nn.Module):
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)
super().__init__(
compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
@@ -54,10 +53,8 @@ def test_torch_compile_wrapper():
# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x,
None).item() == 6 # dispatch to the first compiled code
assert wrapper(
new_x, cache).item() == 5 # dispatch to the second compiled code
assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code
assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code
for wrapper in wrappers:
# make sure they have independent compiled codes