Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,10 +10,14 @@ from torch import nn
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@@ -27,12 +31,7 @@ RANDOM_SEED = 0
|
||||
|
||||
@support_torch_compile
|
||||
class ParentModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -40,7 +39,6 @@ class ParentModel(nn.Module):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self, mlp_size: int, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
|
||||
@@ -51,17 +49,21 @@ class Attention(nn.Module):
|
||||
nn.init.xavier_normal_(
|
||||
self.pre_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
gain=0.001,
|
||||
)
|
||||
nn.init.xavier_normal_(
|
||||
self.post_attn.weight.data,
|
||||
generator=torch.Generator().manual_seed(RANDOM_SEED),
|
||||
gain=0.001)
|
||||
gain=0.001,
|
||||
)
|
||||
|
||||
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_f32 = x.float()
|
||||
return (x_f32 * torch.rsqrt(
|
||||
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
|
||||
self.rms_norm_weight).to(x.dtype)
|
||||
return (
|
||||
x_f32
|
||||
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
|
||||
* self.rms_norm_weight
|
||||
).to(x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.pre_attn(x)
|
||||
@@ -76,14 +78,15 @@ class Attention(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(mlp_size, hidden_size)
|
||||
|
||||
@@ -93,21 +96,21 @@ class CompiledAttention(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class CompiledAttentionTwo(CompiledAttention):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.attn(x) + x
|
||||
|
||||
|
||||
@ignore_torch_compile
|
||||
class SimpleModelWithTwoGraphs(ParentModel):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mlp_size: int,
|
||||
hidden_size: int,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
# Test will fail without set_model_tag here with error:
|
||||
# "ValueError: too many values to unpack (expected 3)"
|
||||
@@ -142,32 +145,45 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
def run_model(
|
||||
vllm_config: VllmConfig,
|
||||
model: nn.Module,
|
||||
inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(inputs)
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
model(inputs[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1,
|
||||
),
|
||||
):
|
||||
model(inputs[:1])
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
with set_forward_context(
|
||||
{},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2,
|
||||
),
|
||||
):
|
||||
output = model(inputs[:2])
|
||||
|
||||
output = output.cpu()
|
||||
@@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
outputs = []
|
||||
|
||||
# piecewise compile
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
# Pre-allocate memory for CUDAGraph which expects
|
||||
# static tensor addresses
|
||||
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2, # two graphs for the model
|
||||
num_piecewise_graphs_seen=6,
|
||||
# attn_one, attn_two each has 3 piecewise graphs
|
||||
# (pre attn, post attn, silly_attention) each
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
num_graphs_seen=2, # two graphs for the model
|
||||
num_piecewise_graphs_seen=6,
|
||||
# attn_one, attn_two each has 3 piecewise graphs
|
||||
# (pre attn, post attn, silly_attention) each
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
# attn_one, attn_two has pre attn and post attn each, total=4
|
||||
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# no compile or cudagraph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, ))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# piecewise compile without CUDA graph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix='').eval().cuda()
|
||||
model = (
|
||||
SimpleModelWithTwoGraphs(
|
||||
mlp_size=MLP_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
vllm_config=vllm_config,
|
||||
prefix="",
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
):
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
|
||||
Reference in New Issue
Block a user