[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -7,8 +7,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
FusionPass, QuantKey)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||
FusionPass, GroupShape, QuantKey)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
|
||||
self.cutlass_fp8_enabled = cutlass_fp8_enabled
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.key = QuantKey(dtype=FP8_DTYPE,
|
||||
static=static,
|
||||
per_tensor=static,
|
||||
group_shape=group_shape,
|
||||
symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
|
||||
# In pre-nodes, fp8 quant should be there and fused kernels should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
|
||||
find_auto_fn_maybe)
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, fused kernels should be there and fp8 quant should not
|
||||
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
|
||||
find_auto_fn_maybe)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
Reference in New Issue
Block a user